Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
4739b5a0
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4739b5a0
编写于
9月 22, 2021
作者:
H
Hui Zhang
提交者:
GitHub
9月 22, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #850 from PaddlePaddle/dataset
batch WaveDataset
上级
ecb5d4f8
98b15eda
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
132 addition
and
8 deletion
+132
-8
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+2
-2
deepspeech/io/collator.py
deepspeech/io/collator.py
+0
-4
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+128
-0
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+2
-2
未找到文件。
deepspeech/exps/u2/model.py
浏览文件 @
4739b5a0
...
...
@@ -199,11 +199,11 @@ class U2Trainer(Trainer):
report
(
"Rank"
,
dist
.
get_rank
())
report
(
"epoch"
,
self
.
epoch
)
report
(
'step'
,
self
.
iteration
)
report
(
'iter'
,
batch_index
+
1
)
report
(
'total'
,
len
(
self
.
train_loader
))
report
(
"lr"
,
self
.
lr_scheduler
())
self
.
train_batch
(
batch_index
,
batch
,
msg
)
self
.
after_train_batch
()
report
(
'iter'
,
batch_index
+
1
)
report
(
'total'
,
len
(
self
.
train_loader
))
report
(
'reader_cost'
,
dataload_time
)
observation
[
'batch_cost'
]
=
observation
[
'reader_cost'
]
+
observation
[
'step_cost'
]
...
...
deepspeech/io/collator.py
浏览文件 @
4739b5a0
...
...
@@ -292,10 +292,6 @@ class SpeechCollator():
olens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utts
,
xs_pad
,
ilens
,
ys_pad
,
olens
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
...
...
deepspeech/io/dataset.py
浏览文件 @
4739b5a0
...
...
@@ -147,3 +147,131 @@ class TransformDataset(Dataset):
def
__getitem__
(
self
,
idx
):
"""[] operator."""
return
self
.
converter
([
self
.
reader
(
self
.
data
[
idx
],
return_uttid
=
True
)])
class
AudioDataset
(
Dataset
):
def
__init__
(
self
,
data_file
,
max_length
=
10240
,
min_length
=
0
,
token_max_length
=
200
,
token_min_length
=
1
,
batch_type
=
'static'
,
batch_size
=
1
,
max_frames_in_batch
=
0
,
sort
=
True
,
raw_wav
=
True
,
stride_ms
=
10
):
"""Dataset for loading audio data.
Attributes::
data_file: input data file
Plain text data file, each line contains following 7 fields,
which is split by '
\t
':
utt:utt1
feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
text:i love you
token: i <space> l o v e <space> y o u
tokenid: int id of this token
token_shape: M,N # M is the number of token, N is vocab size
max_length: drop utterance which is greater than max_length(10ms), unit 10ms.
min_length: drop utterance which is less than min_length(10ms), unit 10ms.
token_max_length: drop utterance which is greater than token_max_length,
especially when use char unit for english modeling
token_min_length: drop utterance which is less than token_max_length
batch_type: static or dynamic, see max_frames_in_batch(dynamic)
batch_size: number of utterances in a batch,
it's for static batch size.
max_frames_in_batch: max feature frames in a batch,
when batch_type is dynamic, it's for dynamic batch size.
Then batch_size is ignored, we will keep filling the
batch until the total frames in batch up to max_frames_in_batch.
sort: whether to sort all data, so the utterance with the same
length could be filled in a same batch.
raw_wav: use raw wave or extracted featute.
if raw wave is used, dynamic waveform-level augmentation could be used
and the feature is extracted by torchaudio.
if extracted featute(e.g. by kaldi) is used, only feature-level
augmentation such as specaug could be used.
"""
assert
batch_type
in
[
'static'
,
'dynamic'
]
# read manifest
data
=
read_manifest
(
data_file
)
if
sort
:
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"feat_shape"
][
0
])
if
raw_wav
:
assert
data
[
0
][
'feat'
].
split
(
':'
)[
0
].
splitext
()[
-
1
]
not
in
(
'.ark'
,
'.scp'
)
data
=
map
(
lambda
x
:
(
float
(
x
[
'feat_shape'
][
0
])
*
1000
/
stride_ms
))
self
.
input_dim
=
data
[
0
][
'feat_shape'
][
1
]
self
.
output_dim
=
data
[
0
][
'token_shape'
][
1
]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data
=
[]
for
i
in
range
(
len
(
data
)):
length
=
data
[
i
][
'feat_shape'
][
0
]
token_length
=
data
[
i
][
'token_shape'
][
0
]
# remove too lang or too short utt for both input and output
# to prevent from out of memory
if
length
>
max_length
or
length
<
min_length
:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
elif
token_length
>
token_max_length
or
token_length
<
token_min_length
:
pass
else
:
valid_data
.
append
(
data
[
i
])
data
=
valid_data
self
.
minibatch
=
[]
num_data
=
len
(
data
)
# Dynamic batch size
if
batch_type
==
'dynamic'
:
assert
(
max_frames_in_batch
>
0
)
self
.
minibatch
.
append
([])
num_frames_in_batch
=
0
for
i
in
range
(
num_data
):
length
=
data
[
i
][
'feat_shape'
][
0
]
num_frames_in_batch
+=
length
if
num_frames_in_batch
>
max_frames_in_batch
:
self
.
minibatch
.
append
([])
num_frames_in_batch
=
length
self
.
minibatch
[
-
1
].
append
(
data
[
i
])
# Static batch size
else
:
cur
=
0
while
cur
<
num_data
:
end
=
min
(
cur
+
batch_size
,
num_data
)
item
=
[]
for
i
in
range
(
cur
,
end
):
item
.
append
(
data
[
i
])
self
.
minibatch
.
append
(
item
)
cur
=
end
def
__len__
(
self
):
return
len
(
self
.
minibatch
)
def
__getitem__
(
self
,
idx
):
instance
=
self
.
minibatch
[
idx
]
return
instance
[
"utt"
],
instance
[
"feat"
],
instance
[
"text"
]
deepspeech/training/trainer.py
浏览文件 @
4739b5a0
...
...
@@ -247,11 +247,11 @@ class Trainer():
report
(
"Rank"
,
dist
.
get_rank
())
report
(
"epoch"
,
self
.
epoch
)
report
(
'step'
,
self
.
iteration
)
report
(
'iter'
,
batch_index
+
1
)
report
(
'total'
,
len
(
self
.
train_loader
))
report
(
"lr"
,
self
.
lr_scheduler
())
self
.
train_batch
(
batch_index
,
batch
,
msg
)
self
.
after_train_batch
()
report
(
'iter'
,
batch_index
+
1
)
report
(
'total'
,
len
(
self
.
train_loader
))
report
(
'reader_cost'
,
dataload_time
)
observation
[
'batch_cost'
]
=
observation
[
'reader_cost'
]
+
observation
[
'step_cost'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录