Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
00d76542
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
11 个月 前同步成功
通知
204
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
00d76542
编写于
9月 28, 2021
作者:
J
Jackwaterveg
提交者:
GitHub
9月 28, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #864 from PaddlePaddle/collator
refactor st and asr collator
上级
32aaf403
f628e218
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
322 addition
and
861 deletion
+322
-861
deepspeech/exps/u2_st/model.py
deepspeech/exps/u2_st/model.py
+6
-17
deepspeech/frontend/audio.py
deepspeech/frontend/audio.py
+13
-7
deepspeech/frontend/featurizer/speech_featurizer.py
deepspeech/frontend/featurizer/speech_featurizer.py
+29
-48
deepspeech/frontend/speech.py
deepspeech/frontend/speech.py
+8
-3
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+46
-0
deepspeech/io/collator.py
deepspeech/io/collator.py
+204
-155
deepspeech/io/collator_st.py
deepspeech/io/collator_st.py
+0
-631
deepspeech/io/reader.py
deepspeech/io/reader.py
+16
-0
未找到文件。
deepspeech/exps/u2_st/model.py
浏览文件 @
00d76542
...
@@ -28,10 +28,8 @@ from paddle import distributed as dist
...
@@ -28,10 +28,8 @@ from paddle import distributed as dist
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
deepspeech.io.collator_st
import
KaldiPrePorocessedCollator
from
deepspeech.io.collator
import
SpeechCollator
from
deepspeech.io.collator_st
import
SpeechCollator
from
deepspeech.io.collator
import
TripletSpeechCollator
from
deepspeech.io.collator_st
import
TripletKaldiPrePorocessedCollator
from
deepspeech.io.collator_st
import
TripletSpeechCollator
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.io.dataset
import
TripletManifestDataset
from
deepspeech.io.dataset
import
TripletManifestDataset
from
deepspeech.io.sampler
import
SortagradBatchSampler
from
deepspeech.io.sampler
import
SortagradBatchSampler
...
@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
...
@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
config
.
data
.
manifest
=
config
.
data
.
dev_manifest
config
.
data
.
manifest
=
config
.
data
.
dev_manifest
dev_dataset
=
Dataset
.
from_config
(
config
)
dev_dataset
=
Dataset
.
from_config
(
config
)
if
config
.
collator
.
raw_wav
:
if
config
.
model
.
model_conf
.
asr_weight
>
0.
:
if
config
.
model
.
model_conf
.
asr_weight
>
0.
:
Collator
=
TripletSpeechCollator
Collator
=
TripletSpeechCollator
TestCollator
=
SpeechCollator
TestCollator
=
SpeechCollator
else
:
TestCollator
=
Collator
=
SpeechCollator
# Not yet implement the mtl loader for raw_wav.
else
:
else
:
if
config
.
model
.
model_conf
.
asr_weight
>
0.
:
TestCollator
=
Collator
=
SpeechCollator
Collator
=
TripletKaldiPrePorocessedCollator
TestCollator
=
KaldiPrePorocessedCollator
else
:
TestCollator
=
Collator
=
KaldiPrePorocessedCollator
collate_fn_train
=
Collator
.
from_config
(
config
)
collate_fn_train
=
Collator
.
from_config
(
config
)
config
.
collator
.
augmentation_config
=
""
config
.
collator
.
augmentation_config
=
""
collate_fn_dev
=
Collator
.
from_config
(
config
)
collate_fn_dev
=
Collator
.
from_config
(
config
)
...
...
deepspeech/frontend/audio.py
浏览文件 @
00d76542
...
@@ -24,8 +24,10 @@ import soundfile
...
@@ -24,8 +24,10 @@ import soundfile
import
soxbindings
as
sox
import
soxbindings
as
sox
from
scipy
import
signal
from
scipy
import
signal
from
.utility
import
subfile_from_tar
class
AudioSegment
(
object
):
class
AudioSegment
():
"""Monaural audio segment abstraction.
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
:param samples: Audio samples [num_samples x num_channels].
...
@@ -68,16 +70,20 @@ class AudioSegment(object):
...
@@ -68,16 +70,20 @@ class AudioSegment(object):
self
.
duration
,
self
.
rms_db
))
self
.
duration
,
self
.
rms_db
))
@
classmethod
@
classmethod
def
from_file
(
cls
,
file
):
def
from_file
(
cls
,
file
,
infos
=
None
):
"""Create audio segment from audio file.
"""Create audio segment from audio file.
:param filepath: Filepath or file object to audio file.
Args:
:type filepath: str|file
filepath (str|file): Filepath or file object to audio file.
:return: Audio segment instance.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
:rtype: AudioSegment
Returns:
AudioSegment: Audio segment instance.
"""
"""
if
isinstance
(
file
,
str
)
and
re
.
findall
(
r
".seqbin_\d+$"
,
file
):
if
isinstance
(
file
,
str
)
and
re
.
findall
(
r
".seqbin_\d+$"
,
file
):
return
cls
.
from_sequence_file
(
file
)
return
cls
.
from_sequence_file
(
file
)
elif
isinstance
(
file
,
str
)
and
file
.
startswith
(
'tar:'
):
return
cls
.
from_file
(
subfile_from_tar
(
file
,
infos
))
else
:
else
:
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
return
cls
(
samples
,
sample_rate
)
...
...
deepspeech/frontend/featurizer/speech_featurizer.py
浏览文件 @
00d76542
...
@@ -64,8 +64,12 @@ class SpeechFeaturizer():
...
@@ -64,8 +64,12 @@ class SpeechFeaturizer():
target_sample_rate
=
16000
,
target_sample_rate
=
16000
,
use_dB_normalization
=
True
,
use_dB_normalization
=
True
,
target_dB
=-
20
,
target_dB
=-
20
,
dither
=
1.0
):
dither
=
1.0
,
self
.
_audio_featurizer
=
AudioFeaturizer
(
maskctc
=
False
):
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
audio_feature
=
AudioFeaturizer
(
specgram_type
=
specgram_type
,
specgram_type
=
specgram_type
,
feat_dim
=
feat_dim
,
feat_dim
=
feat_dim
,
delta_delta
=
delta_delta
,
delta_delta
=
delta_delta
,
...
@@ -77,8 +81,12 @@ class SpeechFeaturizer():
...
@@ -77,8 +81,12 @@ class SpeechFeaturizer():
use_dB_normalization
=
use_dB_normalization
,
use_dB_normalization
=
use_dB_normalization
,
target_dB
=
target_dB
,
target_dB
=
target_dB
,
dither
=
dither
)
dither
=
dither
)
self
.
_text_featurizer
=
TextFeaturizer
(
unit_type
,
vocab_filepath
,
spm_model_prefix
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
unit_type
,
vocab_filepath
=
vocab_filepath
,
spm_model_prefix
=
spm_model_prefix
,
maskctc
=
maskctc
)
def
featurize
(
self
,
speech_segment
,
keep_transcription_text
):
def
featurize
(
self
,
speech_segment
,
keep_transcription_text
):
"""Extract features for speech segment.
"""Extract features for speech segment.
...
@@ -94,60 +102,33 @@ class SpeechFeaturizer():
...
@@ -94,60 +102,33 @@ class SpeechFeaturizer():
Returns:
Returns:
tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices.
tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices.
"""
"""
spec_feature
=
self
.
_audio_featurizer
.
featurize
(
speech_segment
)
spec_feature
=
self
.
audio_feature
.
featurize
(
speech_segment
)
if
keep_transcription_text
:
if
keep_transcription_text
:
return
spec_feature
,
speech_segment
.
transcript
return
spec_feature
,
speech_segment
.
transcript
if
speech_segment
.
has_token
:
if
speech_segment
.
has_token
:
text_ids
=
speech_segment
.
token_ids
text_ids
=
speech_segment
.
token_ids
else
:
else
:
text_ids
=
self
.
_text_featurizer
.
featurize
(
text_ids
=
self
.
text_feature
.
featurize
(
speech_segment
.
transcript
)
speech_segment
.
transcript
)
return
spec_feature
,
text_ids
return
spec_feature
,
text_ids
@
property
def
text_featurize
(
self
,
text
,
keep_transcription_text
):
def
vocab_size
(
self
):
"""Extract features for speech segment.
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return
self
.
_text_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return
self
.
_text_featurizer
.
vocab_list
@
property
1. For audio parts, extract the audio features.
def
vocab_dict
(
self
):
2. For transcript parts, keep the original text or convert text string
"""Return the vocabulary in dict.
to a list of token indices in char-level.
Returns:
Dict[str, int]:
"""
return
self
.
_text_featurizer
.
vocab_dict
@
property
Args:
def
feature_size
(
self
):
text (str): text.
"""Return the audio feature size.
keep_transcription_text (bool): True, keep transcript text, False, token ids
Returns:
int: audio feature size.
"""
return
self
.
_audio_featurizer
.
feature_size
@
property
def
stride_ms
(
self
):
"""time length in `ms` unit per frame
Returns:
Returns:
float: time(ms)/frame
(str|List[int]): text, or list of token indices.
"""
"""
return
self
.
_audio_featurizer
.
stride_ms
if
keep_transcription_text
:
return
text
@
property
text_ids
=
self
.
text_feature
.
featurize
(
text
)
def
text_feature
(
self
):
return
text_ids
"""Return the text feature object.
Returns:
TextFeaturizer: object.
"""
return
self
.
_text_featurizer
deepspeech/frontend/speech.py
浏览文件 @
00d76542
...
@@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment):
...
@@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment):
return
not
self
.
__eq__
(
other
)
return
not
self
.
__eq__
(
other
)
@
classmethod
@
classmethod
def
from_file
(
cls
,
filepath
,
transcript
,
tokens
=
None
,
token_ids
=
None
):
def
from_file
(
cls
,
filepath
,
transcript
,
tokens
=
None
,
token_ids
=
None
,
infos
=
None
):
"""Create speech segment from audio file and corresponding transcript.
"""Create speech segment from audio file and corresponding transcript.
Args:
Args:
...
@@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment):
...
@@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment):
transcript (str): Transcript text for the speech.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
Returns:
SpeechSegment: Speech segment instance.
SpeechSegment: Speech segment instance.
"""
"""
audio
=
AudioSegment
.
from_file
(
filepath
,
infos
)
audio
=
AudioSegment
.
from_file
(
filepath
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
,
tokens
,
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
,
tokens
,
token_ids
)
token_ids
)
...
...
deepspeech/frontend/utility.py
浏览文件 @
00d76542
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
"""Contains data helper functions."""
"""Contains data helper functions."""
import
json
import
json
import
math
import
math
import
tarfile
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Text
from
typing
import
Text
...
@@ -112,6 +113,51 @@ def read_manifest(
...
@@ -112,6 +113,51 @@ def read_manifest(
return
manifest
return
manifest
# Tar File read
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
def
parse_tar
(
file
):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result
=
{}
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
subfile_from_tar
(
file
,
local_data
=
None
):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
local_data
is
None
:
local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
assert
isinstance
(
local_data
,
TarLocalData
)
if
'tar2info'
not
in
local_data
.
__dict__
:
local_data
.
tar2info
=
{}
if
'tar2object'
not
in
local_data
.
__dict__
:
local_data
.
tar2object
=
{}
if
tarpath
not
in
local_data
.
tar2info
:
fobj
,
infos
=
parse_tar
(
tarpath
)
local_data
.
tar2info
[
tarpath
]
=
infos
local_data
.
tar2object
[
tarpath
]
=
fobj
else
:
fobj
=
local_data
.
tar2object
[
tarpath
]
infos
=
local_data
.
tar2info
[
tarpath
]
return
fobj
.
extractfile
(
infos
[
filename
])
def
rms_to_db
(
rms
:
float
):
def
rms_to_db
(
rms
:
float
):
"""Root Mean Square to dB.
"""Root Mean Square to dB.
...
...
deepspeech/io/collator.py
浏览文件 @
00d76542
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
io
import
io
from
collections
import
namedtuple
from
typing
import
Optional
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
...
@@ -23,96 +22,17 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
...
@@ -23,96 +22,17 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.speech
import
SpeechSegment
from
deepspeech.frontend.speech
import
SpeechSegment
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.frontend.utility
import
TarLocalData
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.io.utility
import
pad_list
from
deepspeech.io.utility
import
pad_list
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
__all__
=
[
"SpeechCollator"
]
__all__
=
[
"SpeechCollator"
,
"TripletSpeechCollator"
]
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
# namedtupe need global for pickle.
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
class
SpeechCollator
():
@
classmethod
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
default
=
CfgNode
(
dict
(
augmentation_config
=
""
,
random_seed
=
0
,
mean_std_filepath
=
""
,
unit_type
=
"char"
,
vocab_filepath
=
""
,
spm_model_prefix
=
""
,
specgram_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
window_ms
=
20.0
,
# ms
n_fft
=
None
,
# fft points
max_freq
=
None
,
# None for samplerate/2
target_sample_rate
=
16000
,
# target sample rate
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
,
# feature dither
keep_transcription_text
=
False
))
if
config
is
not
None
:
config
.
merge_from_other_cfg
(
default
)
return
default
@
classmethod
def
from_config
(
cls
,
config
):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert
'augmentation_config'
in
config
.
collator
assert
'keep_transcription_text'
in
config
.
collator
assert
'mean_std_filepath'
in
config
.
collator
assert
'vocab_filepath'
in
config
.
collator
assert
'specgram_type'
in
config
.
collator
assert
'n_fft'
in
config
.
collator
assert
config
.
collator
if
isinstance
(
config
.
collator
.
augmentation_config
,
(
str
,
bytes
)):
if
config
.
collator
.
augmentation_config
:
aug_file
=
io
.
open
(
config
.
collator
.
augmentation_config
,
mode
=
'r'
,
encoding
=
'utf8'
)
else
:
aug_file
=
io
.
StringIO
(
initial_value
=
'{}'
,
newline
=
''
)
else
:
aug_file
=
config
.
collator
.
augmentation_config
assert
isinstance
(
aug_file
,
io
.
StringIO
)
speech_collator
=
cls
(
aug_file
=
aug_file
,
random_seed
=
0
,
mean_std_filepath
=
config
.
collator
.
mean_std_filepath
,
unit_type
=
config
.
collator
.
unit_type
,
vocab_filepath
=
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
,
specgram_type
=
config
.
collator
.
specgram_type
,
feat_dim
=
config
.
collator
.
feat_dim
,
delta_delta
=
config
.
collator
.
delta_delta
,
stride_ms
=
config
.
collator
.
stride_ms
,
window_ms
=
config
.
collator
.
window_ms
,
n_fft
=
config
.
collator
.
n_fft
,
max_freq
=
config
.
collator
.
max_freq
,
target_sample_rate
=
config
.
collator
.
target_sample_rate
,
use_dB_normalization
=
config
.
collator
.
use_dB_normalization
,
target_dB
=
config
.
collator
.
target_dB
,
dither
=
config
.
collator
.
dither
,
keep_transcription_text
=
config
.
collator
.
keep_transcription_text
)
return
speech_collator
class
SpeechCollatorBase
():
def
__init__
(
def
__init__
(
self
,
self
,
aug_file
,
aug_file
,
...
@@ -121,7 +41,7 @@ class SpeechCollator():
...
@@ -121,7 +41,7 @@ class SpeechCollator():
spm_model_prefix
,
spm_model_prefix
,
random_seed
=
0
,
random_seed
=
0
,
unit_type
=
"char"
,
unit_type
=
"char"
,
spec
gra
m_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
spec
tru
m_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
stride_ms
=
10.0
,
# ms
...
@@ -146,7 +66,7 @@ class SpeechCollator():
...
@@ -146,7 +66,7 @@ class SpeechCollator():
n_fft (int, optional): fft points for rfft. Defaults to None.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
spec
gra
m_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
spec
tru
m_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
...
@@ -159,23 +79,27 @@ class SpeechCollator():
...
@@ -159,23 +79,27 @@ class SpeechCollator():
Padding audio features with zeros to make them have the same shape (or
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
a user-defined shape) within one batch.
"""
"""
self
.
_keep_transcription_text
=
keep_transcription_text
self
.
keep_transcription_text
=
keep_transcription_text
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
feat_dim
=
feat_dim
self
.
loader
=
LoadInputsAndTargets
()
# only for tar filetype
self
.
_local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
self
.
_local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
self
.
_augmentation_pipeline
=
AugmentationPipeline
(
self
.
augmentation
=
AugmentationPipeline
(
augmentation_config
=
aug_file
.
read
(),
random_seed
=
random_seed
)
augmentation_config
=
aug_file
.
read
(),
random_seed
=
random_seed
)
self
.
_normalizer
=
FeatureNormalizer
(
self
.
_normalizer
=
FeatureNormalizer
(
mean_std_filepath
)
if
mean_std_filepath
else
None
mean_std_filepath
)
if
mean_std_filepath
else
None
self
.
_stride_ms
=
stride_ms
self
.
_target_sample_rate
=
target_sample_rate
self
.
_speech_featurizer
=
SpeechFeaturizer
(
self
.
_speech_featurizer
=
SpeechFeaturizer
(
unit_type
=
unit_type
,
unit_type
=
unit_type
,
vocab_filepath
=
vocab_filepath
,
vocab_filepath
=
vocab_filepath
,
spm_model_prefix
=
spm_model_prefix
,
spm_model_prefix
=
spm_model_prefix
,
spec
gram_type
=
specgra
m_type
,
spec
trum_type
=
spectru
m_type
,
feat_dim
=
feat_dim
,
feat_dim
=
feat_dim
,
delta_delta
=
delta_delta
,
delta_delta
=
delta_delta
,
stride_ms
=
stride_ms
,
stride_ms
=
stride_ms
,
...
@@ -187,33 +111,11 @@ class SpeechCollator():
...
@@ -187,33 +111,11 @@ class SpeechCollator():
target_dB
=
target_dB
,
target_dB
=
target_dB
,
dither
=
dither
)
dither
=
dither
)
def
_parse_tar
(
self
,
file
):
self
.
feature_size
=
self
.
_speech_featurizer
.
audio_feature
.
feature_size
"""Parse a tar file to get a tarfile object
self
.
text_feature
=
self
.
_speech_featurizer
.
text_feature
and a map containing tarinfoes
self
.
vocab_dict
=
self
.
text_feature
.
vocab_dict
"""
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
result
=
{}
self
.
vocab_size
=
self
.
text_feature
.
vocab_size
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
_subfile_from_tar
(
self
,
file
):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
'tar2info'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2info
=
{}
if
'tar2object'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2object
=
{}
if
tarpath
not
in
self
.
_local_data
.
tar2info
:
object
,
infoes
=
self
.
_parse_tar
(
tarpath
)
self
.
_local_data
.
tar2info
[
tarpath
]
=
infoes
self
.
_local_data
.
tar2object
[
tarpath
]
=
object
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
def
process_utterance
(
self
,
audio_file
,
transcript
):
def
process_utterance
(
self
,
audio_file
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
"""Load, augment, featurize and normalize for speech data.
...
@@ -226,23 +128,36 @@ class SpeechCollator():
...
@@ -226,23 +128,36 @@ class SpeechCollator():
where transcription part could be token ids or text.
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
:rtype: tuple of (2darray, list)
"""
"""
if
isinstance
(
audio_file
,
str
)
and
audio_file
.
startswith
(
'tar:'
):
filetype
=
self
.
loader
.
file_type
(
audio_file
)
speech_segment
=
SpeechSegment
.
from_file
(
self
.
_subfile_from_tar
(
audio_file
),
transcript
)
if
filetype
!=
'sound'
:
spectrum
=
self
.
loader
.
_get_from_loader
(
audio_file
,
filetype
)
feat_dim
=
spectrum
.
shape
[
1
]
assert
feat_dim
==
self
.
feat_dim
,
f
"expect feat dim
{
self
.
feat_dim
}
, but got
{
feat_dim
}
"
if
self
.
keep_transcription_text
:
transcript_part
=
transcript
else
:
text_ids
=
self
.
text_feature
.
featurize
(
transcript
)
transcript_part
=
text_ids
else
:
else
:
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
)
# read audio
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
,
infos
=
self
.
_local_data
)
# audio augment
self
.
augmentation
.
transform_audio
(
speech_segment
)
# audio augment
# extract speech feature
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
spectrum
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
keep_transcription_text
)
specgram
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
# CMVN spectrum
speech_segment
,
self
.
_keep_transcription_text
)
if
self
.
_normalizer
:
if
self
.
_normalizer
:
spectrum
=
self
.
_normalizer
.
apply
(
spectrum
)
specgram
=
self
.
_normalizer
.
apply
(
specgram
)
# spec
gra
m augment
# spec
tru
m augment
spec
gram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgra
m
)
spec
trum
=
self
.
augmentation
.
transform_feature
(
spectru
m
)
return
spec
gra
m
,
transcript_part
return
spec
tru
m
,
transcript_part
def
__call__
(
self
,
batch
):
def
__call__
(
self
,
batch
):
"""batch examples
"""batch examples
...
@@ -272,16 +187,14 @@ class SpeechCollator():
...
@@ -272,16 +187,14 @@ class SpeechCollator():
audios
.
append
(
audio
)
# [T, D]
audios
.
append
(
audio
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
0
])
audio_lens
.
append
(
audio
.
shape
[
0
])
# text
# text
# for training, text is token ids
# for training, text is token ids, else text is string, convert to unicode ord
# else text is string, convert to unicode ord
tokens
=
[]
tokens
=
[]
if
self
.
_
keep_transcription_text
:
if
self
.
keep_transcription_text
:
assert
isinstance
(
text
,
str
),
(
type
(
text
),
text
)
assert
isinstance
(
text
,
str
),
(
type
(
text
),
text
)
tokens
=
[
ord
(
t
)
for
t
in
text
]
tokens
=
[
ord
(
t
)
for
t
in
text
]
else
:
else
:
tokens
=
text
# token ids
tokens
=
text
# token ids
tokens
=
tokens
if
isinstance
(
tokens
,
np
.
ndarray
)
else
np
.
array
(
tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int64
)
tokens
,
dtype
=
np
.
int64
)
texts
.
append
(
tokens
)
texts
.
append
(
tokens
)
text_lens
.
append
(
tokens
.
shape
[
0
])
text_lens
.
append
(
tokens
.
shape
[
0
])
...
@@ -292,26 +205,162 @@ class SpeechCollator():
...
@@ -292,26 +205,162 @@ class SpeechCollator():
olens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
olens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utts
,
xs_pad
,
ilens
,
ys_pad
,
olens
return
utts
,
xs_pad
,
ilens
,
ys_pad
,
olens
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
class
SpeechCollator
(
SpeechCollatorBase
):
def
vocab_list
(
self
):
@
classmethod
return
self
.
_speech_featurizer
.
vocab_list
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
default
=
CfgNode
(
dict
(
augmentation_config
=
""
,
random_seed
=
0
,
mean_std_filepath
=
""
,
unit_type
=
"char"
,
vocab_filepath
=
""
,
spm_model_prefix
=
""
,
spectrum_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
window_ms
=
20.0
,
# ms
n_fft
=
None
,
# fft points
max_freq
=
None
,
# None for samplerate/2
target_sample_rate
=
16000
,
# target sample rate
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
,
# feature dither
keep_transcription_text
=
False
))
if
config
is
not
None
:
config
.
merge_from_other_cfg
(
default
)
return
default
@
classmethod
def
from_config
(
cls
,
config
):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert
'augmentation_config'
in
config
.
collator
assert
'keep_transcription_text'
in
config
.
collator
assert
'mean_std_filepath'
in
config
.
collator
assert
'vocab_filepath'
in
config
.
collator
assert
'spectrum_type'
in
config
.
collator
assert
'n_fft'
in
config
.
collator
assert
config
.
collator
if
isinstance
(
config
.
collator
.
augmentation_config
,
(
str
,
bytes
)):
if
config
.
collator
.
augmentation_config
:
aug_file
=
io
.
open
(
config
.
collator
.
augmentation_config
,
mode
=
'r'
,
encoding
=
'utf8'
)
else
:
aug_file
=
io
.
StringIO
(
initial_value
=
'{}'
,
newline
=
''
)
else
:
aug_file
=
config
.
collator
.
augmentation_config
assert
isinstance
(
aug_file
,
io
.
StringIO
)
speech_collator
=
cls
(
aug_file
=
aug_file
,
random_seed
=
0
,
mean_std_filepath
=
config
.
collator
.
mean_std_filepath
,
unit_type
=
config
.
collator
.
unit_type
,
vocab_filepath
=
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
,
spectrum_type
=
config
.
collator
.
spectrum_type
,
feat_dim
=
config
.
collator
.
feat_dim
,
delta_delta
=
config
.
collator
.
delta_delta
,
stride_ms
=
config
.
collator
.
stride_ms
,
window_ms
=
config
.
collator
.
window_ms
,
n_fft
=
config
.
collator
.
n_fft
,
max_freq
=
config
.
collator
.
max_freq
,
target_sample_rate
=
config
.
collator
.
target_sample_rate
,
use_dB_normalization
=
config
.
collator
.
use_dB_normalization
,
target_dB
=
config
.
collator
.
target_dB
,
dither
=
config
.
collator
.
dither
,
keep_transcription_text
=
config
.
collator
.
keep_transcription_text
)
return
speech_collator
class
TripletSpeechCollator
(
SpeechCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
@
property
:param audio_file: Filepath or file object of audio file.
def
vocab_dict
(
self
):
:type audio_file: str | file
return
self
.
_speech_featurizer
.
vocab_dict
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
spectrum
,
translation_part
=
super
().
process_utterance
(
audio_file
,
translation
)
transcript_part
=
self
.
_speech_featurizer
.
text_featurize
(
transcript
,
self
.
keep_transcription_text
)
return
spectrum
,
translation_part
,
transcript_part
@
property
def
__call__
(
self
,
batch
):
def
text_feature
(
self
):
"""batch examples
return
self
.
_speech_featurizer
.
text_feature
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
@
property
Returns:
def
feature_size
(
self
):
tuple(audio, text, audio_lens, text_lens): batched data.
return
self
.
_speech_featurizer
.
feature_size
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios
=
[]
audio_lens
=
[]
translation_text
=
[]
translation_text_lens
=
[]
transcription_text
=
[]
transcription_text_lens
=
[]
@
property
utts
=
[]
def
stride_ms
(
self
):
for
utt
,
audio
,
translation
,
transcription
in
batch
:
return
self
.
_speech_featurizer
.
stride_ms
audio
,
translation
,
transcription
=
self
.
process_utterance
(
audio
,
translation
,
transcription
)
#utt
utts
.
append
(
utt
)
# audio
audios
.
append
(
audio
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
0
])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens
=
[[],
[]]
for
idx
,
text
in
enumerate
([
translation
,
transcription
]):
if
self
.
keep_transcription_text
:
assert
isinstance
(
text
,
str
),
(
type
(
text
),
text
)
tokens
[
idx
]
=
[
ord
(
t
)
for
t
in
text
]
else
:
tokens
[
idx
]
=
text
# token ids
tokens
[
idx
]
=
np
.
array
(
tokens
[
idx
],
dtype
=
np
.
int64
)
translation_text
.
append
(
tokens
[
0
])
translation_text_lens
.
append
(
tokens
[
0
].
shape
[
0
])
transcription_text
.
append
(
tokens
[
1
])
transcription_text_lens
.
append
(
tokens
[
1
].
shape
[
0
])
padded_audios
=
pad_sequence
(
audios
,
padding_value
=
0.0
).
astype
(
np
.
float32
)
#[B, T, D]
audio_lens
=
np
.
array
(
audio_lens
).
astype
(
np
.
int64
)
padded_translation
=
pad_sequence
(
translation_text
,
padding_value
=
IGNORE_ID
).
astype
(
np
.
int64
)
translation_lens
=
np
.
array
(
translation_text_lens
).
astype
(
np
.
int64
)
padded_transcription
=
pad_sequence
(
transcription_text
,
padding_value
=
IGNORE_ID
).
astype
(
np
.
int64
)
transcription_lens
=
np
.
array
(
transcription_text_lens
).
astype
(
np
.
int64
)
return
utts
,
padded_audios
,
audio_lens
,
(
padded_translation
,
padded_transcription
),
(
translation_lens
,
transcription_lens
)
deepspeech/io/collator_st.py
已删除
100644 → 0
浏览文件 @
32aaf403
此差异已折叠。
点击以展开。
deepspeech/io/reader.py
浏览文件 @
00d76542
...
@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
...
@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
raise
NotImplementedError
(
raise
NotImplementedError
(
"Not supported: loader_type={}"
.
format
(
filetype
))
"Not supported: loader_type={}"
.
format
(
filetype
))
def
file_type
(
self
,
filepath
):
suffix
=
filepath
.
split
(
":"
)[
0
].
split
(
'.'
)[
1
]
if
suffix
==
'ark'
:
return
'mat'
elif
suffix
==
'scp'
:
return
'scp'
elif
suffix
==
'npy'
:
return
'npy'
elif
suffix
==
'npz'
:
return
'npz'
elif
suffix
in
[
'wav'
,
'flac'
]:
# PCM16
return
'sound'
else
:
raise
ValueError
(
f
"Not support filetype:
{
suffix
}
"
)
class
SoundHDF5File
():
class
SoundHDF5File
():
"""Collecting sound files to a HDF5 file
"""Collecting sound files to a HDF5 file
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录