Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
00d76542
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
00d76542
编写于
9月 28, 2021
作者:
J
Jackwaterveg
提交者:
GitHub
9月 28, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #864 from PaddlePaddle/collator
refactor st and asr collator
上级
32aaf403
f628e218
变更
8
展开全部
显示空白变更内容
内联
并排
Showing
8 changed file
with
322 addition
and
861 deletion
+322
-861
deepspeech/exps/u2_st/model.py
deepspeech/exps/u2_st/model.py
+6
-17
deepspeech/frontend/audio.py
deepspeech/frontend/audio.py
+13
-7
deepspeech/frontend/featurizer/speech_featurizer.py
deepspeech/frontend/featurizer/speech_featurizer.py
+29
-48
deepspeech/frontend/speech.py
deepspeech/frontend/speech.py
+8
-3
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+46
-0
deepspeech/io/collator.py
deepspeech/io/collator.py
+204
-155
deepspeech/io/collator_st.py
deepspeech/io/collator_st.py
+0
-631
deepspeech/io/reader.py
deepspeech/io/reader.py
+16
-0
未找到文件。
deepspeech/exps/u2_st/model.py
浏览文件 @
00d76542
...
...
@@ -28,10 +28,8 @@ from paddle import distributed as dist
from
paddle.io
import
DataLoader
from
yacs.config
import
CfgNode
from
deepspeech.io.collator_st
import
KaldiPrePorocessedCollator
from
deepspeech.io.collator_st
import
SpeechCollator
from
deepspeech.io.collator_st
import
TripletKaldiPrePorocessedCollator
from
deepspeech.io.collator_st
import
TripletSpeechCollator
from
deepspeech.io.collator
import
SpeechCollator
from
deepspeech.io.collator
import
TripletSpeechCollator
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.io.dataset
import
TripletManifestDataset
from
deepspeech.io.sampler
import
SortagradBatchSampler
...
...
@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
config
.
data
.
manifest
=
config
.
data
.
dev_manifest
dev_dataset
=
Dataset
.
from_config
(
config
)
if
config
.
collator
.
raw_wav
:
if
config
.
model
.
model_conf
.
asr_weight
>
0.
:
Collator
=
TripletSpeechCollator
TestCollator
=
SpeechCollator
else
:
TestCollator
=
Collator
=
SpeechCollator
# Not yet implement the mtl loader for raw_wav.
else
:
if
config
.
model
.
model_conf
.
asr_weight
>
0.
:
Collator
=
TripletKaldiPrePorocessedCollator
TestCollator
=
KaldiPrePorocessedCollator
else
:
TestCollator
=
Collator
=
KaldiPrePorocessedCollator
collate_fn_train
=
Collator
.
from_config
(
config
)
config
.
collator
.
augmentation_config
=
""
collate_fn_dev
=
Collator
.
from_config
(
config
)
...
...
deepspeech/frontend/audio.py
浏览文件 @
00d76542
...
...
@@ -24,8 +24,10 @@ import soundfile
import
soxbindings
as
sox
from
scipy
import
signal
from
.utility
import
subfile_from_tar
class
AudioSegment
(
object
):
class
AudioSegment
():
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
...
...
@@ -68,16 +70,20 @@ class AudioSegment(object):
self
.
duration
,
self
.
rms_db
))
@
classmethod
def
from_file
(
cls
,
file
):
def
from_file
(
cls
,
file
,
infos
=
None
):
"""Create audio segment from audio file.
:param filepath: Filepath or file object to audio file.
:type filepath: str|file
:return: Audio segment instance.
:rtype: AudioSegment
Args:
filepath (str|file): Filepath or file object to audio file.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
AudioSegment: Audio segment instance.
"""
if
isinstance
(
file
,
str
)
and
re
.
findall
(
r
".seqbin_\d+$"
,
file
):
return
cls
.
from_sequence_file
(
file
)
elif
isinstance
(
file
,
str
)
and
file
.
startswith
(
'tar:'
):
return
cls
.
from_file
(
subfile_from_tar
(
file
,
infos
))
else
:
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
...
...
deepspeech/frontend/featurizer/speech_featurizer.py
浏览文件 @
00d76542
...
...
@@ -64,8 +64,12 @@ class SpeechFeaturizer():
target_sample_rate
=
16000
,
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
):
self
.
_audio_featurizer
=
AudioFeaturizer
(
dither
=
1.0
,
maskctc
=
False
):
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
audio_feature
=
AudioFeaturizer
(
specgram_type
=
specgram_type
,
feat_dim
=
feat_dim
,
delta_delta
=
delta_delta
,
...
...
@@ -77,8 +81,12 @@ class SpeechFeaturizer():
use_dB_normalization
=
use_dB_normalization
,
target_dB
=
target_dB
,
dither
=
dither
)
self
.
_text_featurizer
=
TextFeaturizer
(
unit_type
,
vocab_filepath
,
spm_model_prefix
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
unit_type
,
vocab_filepath
=
vocab_filepath
,
spm_model_prefix
=
spm_model_prefix
,
maskctc
=
maskctc
)
def
featurize
(
self
,
speech_segment
,
keep_transcription_text
):
"""Extract features for speech segment.
...
...
@@ -94,60 +102,33 @@ class SpeechFeaturizer():
Returns:
tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices.
"""
spec_feature
=
self
.
_audio_featurizer
.
featurize
(
speech_segment
)
spec_feature
=
self
.
audio_feature
.
featurize
(
speech_segment
)
if
keep_transcription_text
:
return
spec_feature
,
speech_segment
.
transcript
if
speech_segment
.
has_token
:
text_ids
=
speech_segment
.
token_ids
else
:
text_ids
=
self
.
_text_featurizer
.
featurize
(
speech_segment
.
transcript
)
text_ids
=
self
.
text_feature
.
featurize
(
speech_segment
.
transcript
)
return
spec_feature
,
text_ids
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return
self
.
_text_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return
self
.
_text_featurizer
.
vocab_list
def
text_featurize
(
self
,
text
,
keep_transcription_text
):
"""Extract features for speech segment.
@
property
def
vocab_dict
(
self
):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
return
self
.
_text_featurizer
.
vocab_dict
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
@
property
def
feature_size
(
self
):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return
self
.
_audio_featurizer
.
feature_size
Args:
text (str): text.
keep_transcription_text (bool): True, keep transcript text, False, token ids
@
property
def
stride_ms
(
self
):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
(str|List[int]): text, or list of token indices.
"""
return
self
.
_audio_featurizer
.
stride_ms
if
keep_transcription_text
:
return
text
@
property
def
text_feature
(
self
):
"""Return the text feature object.
Returns:
TextFeaturizer: object.
"""
return
self
.
_text_featurizer
text_ids
=
self
.
text_feature
.
featurize
(
text
)
return
text_ids
deepspeech/frontend/speech.py
浏览文件 @
00d76542
...
...
@@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment):
return
not
self
.
__eq__
(
other
)
@
classmethod
def
from_file
(
cls
,
filepath
,
transcript
,
tokens
=
None
,
token_ids
=
None
):
def
from_file
(
cls
,
filepath
,
transcript
,
tokens
=
None
,
token_ids
=
None
,
infos
=
None
):
"""Create speech segment from audio file and corresponding transcript.
Args:
...
...
@@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment):
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio
=
AudioSegment
.
from_file
(
filepath
)
audio
=
AudioSegment
.
from_file
(
filepath
,
infos
)
return
cls
(
audio
.
samples
,
audio
.
sample_rate
,
transcript
,
tokens
,
token_ids
)
...
...
deepspeech/frontend/utility.py
浏览文件 @
00d76542
...
...
@@ -14,6 +14,7 @@
"""Contains data helper functions."""
import
json
import
math
import
tarfile
from
typing
import
List
from
typing
import
Optional
from
typing
import
Text
...
...
@@ -112,6 +113,51 @@ def read_manifest(
return
manifest
# Tar File read
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
def
parse_tar
(
file
):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result
=
{}
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
subfile_from_tar
(
file
,
local_data
=
None
):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
local_data
is
None
:
local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
assert
isinstance
(
local_data
,
TarLocalData
)
if
'tar2info'
not
in
local_data
.
__dict__
:
local_data
.
tar2info
=
{}
if
'tar2object'
not
in
local_data
.
__dict__
:
local_data
.
tar2object
=
{}
if
tarpath
not
in
local_data
.
tar2info
:
fobj
,
infos
=
parse_tar
(
tarpath
)
local_data
.
tar2info
[
tarpath
]
=
infos
local_data
.
tar2object
[
tarpath
]
=
fobj
else
:
fobj
=
local_data
.
tar2object
[
tarpath
]
infos
=
local_data
.
tar2info
[
tarpath
]
return
fobj
.
extractfile
(
infos
[
filename
])
def
rms_to_db
(
rms
:
float
):
"""Root Mean Square to dB.
...
...
deepspeech/io/collator.py
浏览文件 @
00d76542
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
io
from
collections
import
namedtuple
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -23,96 +22,17 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from
deepspeech.frontend.normalizer
import
FeatureNormalizer
from
deepspeech.frontend.speech
import
SpeechSegment
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.frontend.utility
import
TarLocalData
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.io.utility
import
pad_list
from
deepspeech.utils.log
import
Log
__all__
=
[
"SpeechCollator"
]
__all__
=
[
"SpeechCollator"
,
"TripletSpeechCollator"
]
logger
=
Log
(
__name__
).
getlog
()
# namedtupe need global for pickle.
TarLocalData
=
namedtuple
(
'TarLocalData'
,
[
'tar2info'
,
'tar2object'
])
class
SpeechCollator
():
@
classmethod
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
default
=
CfgNode
(
dict
(
augmentation_config
=
""
,
random_seed
=
0
,
mean_std_filepath
=
""
,
unit_type
=
"char"
,
vocab_filepath
=
""
,
spm_model_prefix
=
""
,
specgram_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
window_ms
=
20.0
,
# ms
n_fft
=
None
,
# fft points
max_freq
=
None
,
# None for samplerate/2
target_sample_rate
=
16000
,
# target sample rate
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
,
# feature dither
keep_transcription_text
=
False
))
if
config
is
not
None
:
config
.
merge_from_other_cfg
(
default
)
return
default
@
classmethod
def
from_config
(
cls
,
config
):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert
'augmentation_config'
in
config
.
collator
assert
'keep_transcription_text'
in
config
.
collator
assert
'mean_std_filepath'
in
config
.
collator
assert
'vocab_filepath'
in
config
.
collator
assert
'specgram_type'
in
config
.
collator
assert
'n_fft'
in
config
.
collator
assert
config
.
collator
if
isinstance
(
config
.
collator
.
augmentation_config
,
(
str
,
bytes
)):
if
config
.
collator
.
augmentation_config
:
aug_file
=
io
.
open
(
config
.
collator
.
augmentation_config
,
mode
=
'r'
,
encoding
=
'utf8'
)
else
:
aug_file
=
io
.
StringIO
(
initial_value
=
'{}'
,
newline
=
''
)
else
:
aug_file
=
config
.
collator
.
augmentation_config
assert
isinstance
(
aug_file
,
io
.
StringIO
)
speech_collator
=
cls
(
aug_file
=
aug_file
,
random_seed
=
0
,
mean_std_filepath
=
config
.
collator
.
mean_std_filepath
,
unit_type
=
config
.
collator
.
unit_type
,
vocab_filepath
=
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
,
specgram_type
=
config
.
collator
.
specgram_type
,
feat_dim
=
config
.
collator
.
feat_dim
,
delta_delta
=
config
.
collator
.
delta_delta
,
stride_ms
=
config
.
collator
.
stride_ms
,
window_ms
=
config
.
collator
.
window_ms
,
n_fft
=
config
.
collator
.
n_fft
,
max_freq
=
config
.
collator
.
max_freq
,
target_sample_rate
=
config
.
collator
.
target_sample_rate
,
use_dB_normalization
=
config
.
collator
.
use_dB_normalization
,
target_dB
=
config
.
collator
.
target_dB
,
dither
=
config
.
collator
.
dither
,
keep_transcription_text
=
config
.
collator
.
keep_transcription_text
)
return
speech_collator
class
SpeechCollatorBase
():
def
__init__
(
self
,
aug_file
,
...
...
@@ -121,7 +41,7 @@ class SpeechCollator():
spm_model_prefix
,
random_seed
=
0
,
unit_type
=
"char"
,
spec
gra
m_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
spec
tru
m_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
...
...
@@ -146,7 +66,7 @@ class SpeechCollator():
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
spec
gra
m_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
spec
tru
m_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
...
...
@@ -159,23 +79,27 @@ class SpeechCollator():
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self
.
_keep_transcription_text
=
keep_transcription_text
self
.
keep_transcription_text
=
keep_transcription_text
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
feat_dim
=
feat_dim
self
.
loader
=
LoadInputsAndTargets
()
# only for tar filetype
self
.
_local_data
=
TarLocalData
(
tar2info
=
{},
tar2object
=
{})
self
.
_augmentation_pipeline
=
AugmentationPipeline
(
self
.
augmentation
=
AugmentationPipeline
(
augmentation_config
=
aug_file
.
read
(),
random_seed
=
random_seed
)
self
.
_normalizer
=
FeatureNormalizer
(
mean_std_filepath
)
if
mean_std_filepath
else
None
self
.
_stride_ms
=
stride_ms
self
.
_target_sample_rate
=
target_sample_rate
self
.
_speech_featurizer
=
SpeechFeaturizer
(
unit_type
=
unit_type
,
vocab_filepath
=
vocab_filepath
,
spm_model_prefix
=
spm_model_prefix
,
spec
gram_type
=
specgra
m_type
,
spec
trum_type
=
spectru
m_type
,
feat_dim
=
feat_dim
,
delta_delta
=
delta_delta
,
stride_ms
=
stride_ms
,
...
...
@@ -187,33 +111,11 @@ class SpeechCollator():
target_dB
=
target_dB
,
dither
=
dither
)
def
_parse_tar
(
self
,
file
):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result
=
{}
f
=
tarfile
.
open
(
file
)
for
tarinfo
in
f
.
getmembers
():
result
[
tarinfo
.
name
]
=
tarinfo
return
f
,
result
def
_subfile_from_tar
(
self
,
file
):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath
,
filename
=
file
.
split
(
':'
,
1
)[
1
].
split
(
'#'
,
1
)
if
'tar2info'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2info
=
{}
if
'tar2object'
not
in
self
.
_local_data
.
__dict__
:
self
.
_local_data
.
tar2object
=
{}
if
tarpath
not
in
self
.
_local_data
.
tar2info
:
object
,
infoes
=
self
.
_parse_tar
(
tarpath
)
self
.
_local_data
.
tar2info
[
tarpath
]
=
infoes
self
.
_local_data
.
tar2object
[
tarpath
]
=
object
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
self
.
feature_size
=
self
.
_speech_featurizer
.
audio_feature
.
feature_size
self
.
text_feature
=
self
.
_speech_featurizer
.
text_feature
self
.
vocab_dict
=
self
.
text_feature
.
vocab_dict
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
self
.
vocab_size
=
self
.
text_feature
.
vocab_size
def
process_utterance
(
self
,
audio_file
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
...
...
@@ -226,23 +128,36 @@ class SpeechCollator():
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if
isinstance
(
audio_file
,
str
)
and
audio_file
.
startswith
(
'tar:'
):
speech_segment
=
SpeechSegment
.
from_file
(
self
.
_subfile_from_tar
(
audio_file
),
transcript
)
else
:
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
)
filetype
=
self
.
loader
.
file_type
(
audio_file
)
if
filetype
!=
'sound'
:
spectrum
=
self
.
loader
.
_get_from_loader
(
audio_file
,
filetype
)
feat_dim
=
spectrum
.
shape
[
1
]
assert
feat_dim
==
self
.
feat_dim
,
f
"expect feat dim
{
self
.
feat_dim
}
, but got
{
feat_dim
}
"
if
self
.
keep_transcription_text
:
transcript_part
=
transcript
else
:
text_ids
=
self
.
text_feature
.
featurize
(
transcript
)
transcript_part
=
text_ids
else
:
# read audio
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
,
infos
=
self
.
_local_data
)
# audio augment
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
self
.
augmentation
.
transform_audio
(
speech_segment
)
# extract speech feature
spectrum
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
keep_transcription_text
)
specgram
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
_keep_transcription_text
)
# CMVN spectrum
if
self
.
_normalizer
:
specgram
=
self
.
_normalizer
.
apply
(
specgra
m
)
spectrum
=
self
.
_normalizer
.
apply
(
spectru
m
)
# spec
gra
m augment
spec
gram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgra
m
)
return
spec
gra
m
,
transcript_part
# spec
tru
m augment
spec
trum
=
self
.
augmentation
.
transform_feature
(
spectru
m
)
return
spec
tru
m
,
transcript_part
def
__call__
(
self
,
batch
):
"""batch examples
...
...
@@ -272,16 +187,14 @@ class SpeechCollator():
audios
.
append
(
audio
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
0
])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
# for training, text is token ids, else text is string, convert to unicode ord
tokens
=
[]
if
self
.
_
keep_transcription_text
:
if
self
.
keep_transcription_text
:
assert
isinstance
(
text
,
str
),
(
type
(
text
),
text
)
tokens
=
[
ord
(
t
)
for
t
in
text
]
else
:
tokens
=
text
# token ids
tokens
=
tokens
if
isinstance
(
tokens
,
np
.
ndarray
)
else
np
.
array
(
tokens
,
dtype
=
np
.
int64
)
tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int64
)
texts
.
append
(
tokens
)
text_lens
.
append
(
tokens
.
shape
[
0
])
...
...
@@ -292,26 +205,162 @@ class SpeechCollator():
olens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utts
,
xs_pad
,
ilens
,
ys_pad
,
olens
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_speech_featurizer
.
vocab_list
class
SpeechCollator
(
SpeechCollatorBase
):
@
classmethod
def
params
(
cls
,
config
:
Optional
[
CfgNode
]
=
None
)
->
CfgNode
:
default
=
CfgNode
(
dict
(
augmentation_config
=
""
,
random_seed
=
0
,
mean_std_filepath
=
""
,
unit_type
=
"char"
,
vocab_filepath
=
""
,
spm_model_prefix
=
""
,
spectrum_type
=
'linear'
,
# 'linear', 'mfcc', 'fbank'
feat_dim
=
0
,
# 'mfcc', 'fbank'
delta_delta
=
False
,
# 'mfcc', 'fbank'
stride_ms
=
10.0
,
# ms
window_ms
=
20.0
,
# ms
n_fft
=
None
,
# fft points
max_freq
=
None
,
# None for samplerate/2
target_sample_rate
=
16000
,
# target sample rate
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
,
# feature dither
keep_transcription_text
=
False
))
if
config
is
not
None
:
config
.
merge_from_other_cfg
(
default
)
return
default
@
classmethod
def
from_config
(
cls
,
config
):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert
'augmentation_config'
in
config
.
collator
assert
'keep_transcription_text'
in
config
.
collator
assert
'mean_std_filepath'
in
config
.
collator
assert
'vocab_filepath'
in
config
.
collator
assert
'spectrum_type'
in
config
.
collator
assert
'n_fft'
in
config
.
collator
assert
config
.
collator
if
isinstance
(
config
.
collator
.
augmentation_config
,
(
str
,
bytes
)):
if
config
.
collator
.
augmentation_config
:
aug_file
=
io
.
open
(
config
.
collator
.
augmentation_config
,
mode
=
'r'
,
encoding
=
'utf8'
)
else
:
aug_file
=
io
.
StringIO
(
initial_value
=
'{}'
,
newline
=
''
)
else
:
aug_file
=
config
.
collator
.
augmentation_config
assert
isinstance
(
aug_file
,
io
.
StringIO
)
speech_collator
=
cls
(
aug_file
=
aug_file
,
random_seed
=
0
,
mean_std_filepath
=
config
.
collator
.
mean_std_filepath
,
unit_type
=
config
.
collator
.
unit_type
,
vocab_filepath
=
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
,
spectrum_type
=
config
.
collator
.
spectrum_type
,
feat_dim
=
config
.
collator
.
feat_dim
,
delta_delta
=
config
.
collator
.
delta_delta
,
stride_ms
=
config
.
collator
.
stride_ms
,
window_ms
=
config
.
collator
.
window_ms
,
n_fft
=
config
.
collator
.
n_fft
,
max_freq
=
config
.
collator
.
max_freq
,
target_sample_rate
=
config
.
collator
.
target_sample_rate
,
use_dB_normalization
=
config
.
collator
.
use_dB_normalization
,
target_dB
=
config
.
collator
.
target_dB
,
dither
=
config
.
collator
.
dither
,
keep_transcription_text
=
config
.
collator
.
keep_transcription_text
)
return
speech_collator
class
TripletSpeechCollator
(
SpeechCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
spectrum
,
translation_part
=
super
().
process_utterance
(
audio_file
,
translation
)
transcript_part
=
self
.
_speech_featurizer
.
text_featurize
(
transcript
,
self
.
keep_transcription_text
)
return
spectrum
,
translation_part
,
transcript_part
@
property
def
vocab_dict
(
self
):
return
self
.
_speech_featurizer
.
vocab_dict
def
__call__
(
self
,
batch
):
"""batch examples
@
property
def
text_feature
(
self
):
return
self
.
_speech_featurizer
.
text_feature
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
@
property
def
feature_size
(
self
):
return
self
.
_speech_featurizer
.
feature_size
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios
=
[]
audio_lens
=
[]
translation_text
=
[]
translation_text_lens
=
[]
transcription_text
=
[]
transcription_text_lens
=
[]
@
property
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
utts
=
[]
for
utt
,
audio
,
translation
,
transcription
in
batch
:
audio
,
translation
,
transcription
=
self
.
process_utterance
(
audio
,
translation
,
transcription
)
#utt
utts
.
append
(
utt
)
# audio
audios
.
append
(
audio
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
0
])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens
=
[[],
[]]
for
idx
,
text
in
enumerate
([
translation
,
transcription
]):
if
self
.
keep_transcription_text
:
assert
isinstance
(
text
,
str
),
(
type
(
text
),
text
)
tokens
[
idx
]
=
[
ord
(
t
)
for
t
in
text
]
else
:
tokens
[
idx
]
=
text
# token ids
tokens
[
idx
]
=
np
.
array
(
tokens
[
idx
],
dtype
=
np
.
int64
)
translation_text
.
append
(
tokens
[
0
])
translation_text_lens
.
append
(
tokens
[
0
].
shape
[
0
])
transcription_text
.
append
(
tokens
[
1
])
transcription_text_lens
.
append
(
tokens
[
1
].
shape
[
0
])
padded_audios
=
pad_sequence
(
audios
,
padding_value
=
0.0
).
astype
(
np
.
float32
)
#[B, T, D]
audio_lens
=
np
.
array
(
audio_lens
).
astype
(
np
.
int64
)
padded_translation
=
pad_sequence
(
translation_text
,
padding_value
=
IGNORE_ID
).
astype
(
np
.
int64
)
translation_lens
=
np
.
array
(
translation_text_lens
).
astype
(
np
.
int64
)
padded_transcription
=
pad_sequence
(
transcription_text
,
padding_value
=
IGNORE_ID
).
astype
(
np
.
int64
)
transcription_lens
=
np
.
array
(
transcription_text_lens
).
astype
(
np
.
int64
)
return
utts
,
padded_audios
,
audio_lens
,
(
padded_translation
,
padded_transcription
),
(
translation_lens
,
transcription_lens
)
deepspeech/io/collator_st.py
已删除
100644 → 0
浏览文件 @
32aaf403
此差异已折叠。
点击以展开。
deepspeech/io/reader.py
浏览文件 @
00d76542
...
...
@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
raise
NotImplementedError
(
"Not supported: loader_type={}"
.
format
(
filetype
))
def
file_type
(
self
,
filepath
):
suffix
=
filepath
.
split
(
":"
)[
0
].
split
(
'.'
)[
1
]
if
suffix
==
'ark'
:
return
'mat'
elif
suffix
==
'scp'
:
return
'scp'
elif
suffix
==
'npy'
:
return
'npy'
elif
suffix
==
'npz'
:
return
'npz'
elif
suffix
in
[
'wav'
,
'flac'
]:
# PCM16
return
'sound'
else
:
raise
ValueError
(
f
"Not support filetype:
{
suffix
}
"
)
class
SoundHDF5File
():
"""Collecting sound files to a HDF5 file
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录