Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
4af774d8
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4af774d8
编写于
8月 18, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dataloader; check augmenter base class type
上级
64cf538e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
16 deletion
+39
-16
deepspeech/frontend/augmentor/augmentation.py
deepspeech/frontend/augmentor/augmentation.py
+2
-0
deepspeech/io/dataloader.py
deepspeech/io/dataloader.py
+37
-16
未找到文件。
deepspeech/frontend/augmentor/augmentation.py
浏览文件 @
4af774d8
...
...
@@ -18,6 +18,7 @@ from inspect import signature
import
numpy
as
np
from
deepspeech.frontend.augmentor.base
import
AugmentorBase
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.log
import
Log
...
...
@@ -209,6 +210,7 @@ class AugmentationPipeline():
def
_get_augmentor
(
self
,
augmentor_type
,
params
):
"""Return an augmentation model by the type name, and pass in params."""
class_obj
=
dynamic_import
(
augmentor_type
,
import_alias
)
assert
issubclass
(
class_obj
,
AugmentorBase
)
try
:
obj
=
class_obj
(
self
.
_rng
,
**
params
)
except
Exception
:
...
...
deepspeech/io/dataloader.py
浏览文件 @
4af774d8
...
...
@@ -15,8 +15,8 @@ from paddle.io import DataLoader
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.io.batchfy
import
make_batchset
from
deepspeech.io.converter
import
CustomConverter
from
deepspeech.io.dataset
import
TransformDataset
from
deepspeech.io.reader
import
CustomConverter
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.utils.log
import
Log
...
...
@@ -46,7 +46,6 @@ class BatchDataLoader():
num_encs
:
int
=
1
):
self
.
json_file
=
json_file
self
.
train_mode
=
train_mode
self
.
use_sortagrad
=
sortagrad
==
-
1
or
sortagrad
>
0
self
.
batch_size
=
batch_size
self
.
maxlen_in
=
maxlen_in
...
...
@@ -56,20 +55,17 @@ class BatchDataLoader():
self
.
batch_frames_in
=
batch_frames_in
self
.
batch_frames_out
=
batch_frames_out
self
.
batch_frames_inout
=
batch_frames_inout
self
.
subsampling_factor
=
subsampling_factor
self
.
num_encs
=
num_encs
self
.
preprocess_conf
=
preprocess_conf
self
.
n_iter_processes
=
n_iter_processes
# read json data
data_json
=
read_manifest
(
json_file
)
logger
.
info
(
f
"load
{
json_file
}
file."
)
self
.
data_json
=
read_manifest
(
json_file
)
# make minibatch list (variable length)
self
.
data
=
make_batchset
(
data_json
,
self
.
minibaches
=
make_batchset
(
self
.
data_json
,
batch_size
,
maxlen_in
,
maxlen_out
,
...
...
@@ -83,9 +79,9 @@ class BatchDataLoader():
batch_frames_inout
=
batch_frames_inout
,
iaxis
=
0
,
oaxis
=
0
,
)
logger
.
info
(
f
"batchfy data
{
json_file
}
:
{
len
(
self
.
data
)
}
."
)
self
.
load
=
LoadInputsAndTargets
(
# data reader
self
.
reader
=
LoadInputsAndTargets
(
mode
=
"asr"
,
load_output
=
True
,
preprocess_conf
=
preprocess_conf
,
...
...
@@ -96,7 +92,7 @@ class BatchDataLoader():
# Setup a converter
if
num_encs
==
1
:
self
.
converter
=
CustomConverter
(
subsampling_factor
=
subsampling_factor
,
dtype
=
dtype
)
subsampling_factor
=
subsampling_factor
,
dtype
=
np
.
float32
)
else
:
assert
NotImplementedError
(
"not impl CustomConverterMulEnc."
)
...
...
@@ -104,14 +100,39 @@ class BatchDataLoader():
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self
.
train_loader
=
DataLoader
(
dataset
=
TransformDataset
(
self
.
data
,
lambda
data
:
self
.
converter
([
self
.
load
(
data
,
return_uttid
=
True
)])),
self
.
dataset
=
TransformDataset
(
self
.
minibaches
,
lambda
data
:
self
.
converter
([
self
.
reader
(
data
,
return_uttid
=
True
)]))
self
.
dataloader
=
DataLoader
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
use_sortagrad
if
train_mode
else
False
,
collate_fn
=
lambda
x
:
x
[
0
],
num_workers
=
n_iter_processes
,
)
logger
.
info
(
f
"dataloader for
{
json_file
}
."
)
def
__repr__
(
self
):
return
f
"DataLoader
{
self
.
json_file
}
-
{
self
.
train_mode
}
-
{
self
.
use_sortagrad
}
"
echo
=
f
"<
{
self
.
__class__
.
__module__
}
.
{
self
.
__class__
.
__name__
}
object at
{
hex
(
id
(
self
))
}
> "
echo
+=
f
"train_mode:
{
self
.
train_mode
}
, "
echo
+=
f
"sortagrad:
{
self
.
use_sortagrad
}
, "
echo
+=
f
"batch_size:
{
self
.
batch_size
}
, "
echo
+=
f
"maxlen_in:
{
self
.
maxlen_in
}
, "
echo
+=
f
"maxlen_out:
{
self
.
maxlen_out
}
, "
echo
+=
f
"batch_count:
{
self
.
batch_count
}
, "
echo
+=
f
"batch_bins:
{
self
.
batch_bins
}
, "
echo
+=
f
"batch_frames_in:
{
self
.
batch_frames_in
}
, "
echo
+=
f
"batch_frames_out:
{
self
.
batch_frames_out
}
, "
echo
+=
f
"batch_frames_inout:
{
self
.
batch_frames_inout
}
, "
echo
+=
f
"subsampling_factor:
{
self
.
subsampling_factor
}
, "
echo
+=
f
"num_encs:
{
self
.
num_encs
}
, "
echo
+=
f
"num_workers:
{
self
.
n_iter_processes
}
, "
echo
+=
f
"file:
{
self
.
json_file
}
"
return
echo
def
__len__
(
self
):
return
len
(
self
.
dataloader
)
def
__iter__
(
self
):
return
self
.
dataloader
.
__iter__
()
def
__call__
(
self
):
return
self
.
__iter__
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录