Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
a1e1463d
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
11
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a1e1463d
编写于
2月 06, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dataset' into 'master'
update dataset and datacargo's design. See merge request !7
上级
2ca5c810
837749a3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
219 additions
and
35 deletions
+219
-35
parakeet/__init__.py
parakeet/__init__.py
+2
-0
parakeet/data/datacargo.py
parakeet/data/datacargo.py
+30
-15
parakeet/data/dataset.py
parakeet/data/dataset.py
+187
-20
未找到文件。
parakeet/__init__.py
浏览文件 @
a1e1463d
__version__
=
"0.0.0"
from
.
import
data
,
g2p
,
models
,
modules
,
utils
parakeet/data/datacargo.py
浏览文件 @
a1e1463d
from
.sampler
import
SequentialSampler
,
RandomSampler
,
BatchSampler
class
DataCargo
(
object
):
def
__init__
(
self
,
dataset
,
batch_size
=
1
,
sampler
=
None
,
shuffle
=
False
,
batch_sampler
=
None
,
drop_last
=
False
):
def
__init__
(
self
,
dataset
,
batch_fn
=
None
,
batch_size
=
1
,
sampler
=
None
,
shuffle
=
False
,
batch_sampler
=
None
,
drop_last
=
False
):
self
.
dataset
=
dataset
self
.
batch_fn
=
batch_fn
or
self
.
dataset
.
_batch_examples
if
batch_sampler
is
not
None
:
# auto_collation with custom batch_sampler
if
batch_size
!=
1
or
shuffle
or
sampler
is
not
None
or
drop_last
:
...
...
@@ -15,7 +23,8 @@ class DataCargo(object):
drop_last
=
False
shuffle
=
False
elif
batch_size
is
None
:
raise
ValueError
(
'batch sampler is none. then batch size must not be none.'
)
raise
ValueError
(
'batch sampler is none. then batch size must not be none.'
)
elif
sampler
is
None
:
if
shuffle
:
sampler
=
RandomSampler
(
dataset
)
...
...
@@ -23,18 +32,20 @@ class DataCargo(object):
sampler
=
SequentialSampler
(
dataset
)
# auto_collation without custom batch_sampler
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
)
else
:
batch_sampler
=
BatchSampler
(
sampler
,
batch_size
,
drop_last
)
self
.
batch_size
=
batch_size
self
.
drop_last
=
drop_last
self
.
sampler
=
sampler
self
.
batch_sampler
=
batch_sampler
def
__iter__
(
self
):
return
DataIterator
(
self
)
def
__call__
(
self
):
return
DataIterator
(
self
)
@
property
def
_auto_collation
(
self
):
# we will auto batching
...
...
@@ -49,26 +60,30 @@ class DataCargo(object):
def
__len__
(
self
):
return
len
(
self
.
_index_sampler
)
class DataIterator(object):
    """Iterator over a DataCargo.

    Draws index minibatches from the cargo's index sampler, fetches the
    corresponding examples from the dataset and collates them with the
    cargo's batch function.
    """

    def __init__(self, loader):
        self.loader = loader
        self._dataset = loader.dataset
        self._batch_fn = loader.batch_fn
        self._index_sampler = loader._index_sampler
        self._sampler_iter = iter(self._index_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        # May raise StopIteration when the sampler is exhausted.
        # TODO(chenfeiyu): use dynamic batch size
        indices = self._next_index()
        examples = [self._dataset[i] for i in indices]
        # we can abstract it, too, to use dynamic batch size
        batch = self._batch_fn(examples)  # list[Example] -> Batch
        return batch

    def _next_index(self):
        return next(self._sampler_iter)

    def __len__(self):
        return len(self._index_sampler)
parakeet/data/dataset.py
浏览文件 @
a1e1463d
class Dataset(object):
    """Abstract base for datasets: subclasses load metadata, produce
    single examples and collate example lists into batches."""

    def __init__(self):
        pass

    def _load_metadata(self):
        raise NotImplementedError

    def _get_example(self):
        """return a Record (or Example, Instance according to your glossary)"""
        raise NotImplementedError

    def _batch_examples(self, minibatch):
        """get a list of examples, return a batch, whose structure is the same as an example"""
        raise NotImplementedError

    def _prepare_metadata(self):
        raise NotImplementedError
import
six
import
numpy
as
np
class DatasetMixin(object):
    """standard indexing interface for dataset.

    Subclasses implement ``get_example(i)`` and ``__len__``; the mixin then
    provides ``__getitem__`` for int, slice, list and numpy.ndarray indices,
    plus sequential iteration.
    """

    def __getitem__(self, index):
        """Fetch one example (int index) or a list of examples (slice,
        list or ndarray of indices)."""
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            # The file targets Python 3 (it defines __next__ elsewhere), so
            # the builtin range replaces the six.moves.range py2 shim.
            return [self.get_example(i) for i in range(start, stop, step)]
        elif isinstance(index, (list, np.ndarray)):
            return [self.get_example(i) for i in index]
        else:
            # assumes it is an integer
            return self.get_example(index)

    def get_example(self, i):
        """Return the i-th example; must be overridden by subclasses."""
        raise NotImplementedError

    def __len__(self):
        # Must be overridden: slice indexing and __iter__ rely on it.
        raise NotImplementedError

    def __iter__(self):
        for i in range(len(self)):
            yield self.get_example(i)
class TransformDataset(DatasetMixin):
    """Transform a dataset to another with a transform."""

    def __init__(self, dataset, transform):
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        return len(self._dataset)

    def get_example(self, i):
        # CAUTION: only int is supported?
        # CAUTION: the wrapped dataset must support __getitem__ and __len__
        raw = self._dataset[i]
        return self._transform(raw)
class TupleDataset(object):
    """Zip several equal-length datasets: example i is the tuple
    ``(datasets[0][i], datasets[1][i], ...)``.

    Raises:
        ValueError: if no datasets are given or their lengths differ.
    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            # BUGFIX: compare each dataset's length (len(dataset)), not the
            # number of datasets (len(datasets)), against the first one --
            # the original check let mismatched datasets slip through.
            if len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length."
                    "dataset {} has a different length".format(i))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        # SOA: fetch from every dataset with the same index
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, slice):
            length = len(batches[0])
            # AOS: transpose to a list of per-example tuples
            return [tuple([batch[i] for batch in batches])
                    for i in range(length)]
        else:
            return tuple(batches)

    def __len__(self):
        return self._length
class DictDataset(object):
    """Zip several equal-length datasets by keyword: example i is the dict
    ``{key: datasets[key][i], ...}``.

    Raises:
        ValueError: if no datasets are given or their lengths differ.
    """

    def __init__(self, **datasets):
        if not datasets:
            raise ValueError("no datasets are given")
        length = None
        for key, dataset in datasets.items():
            if length is None:
                length = len(dataset)
            # BUGFIX: compare each dataset's length (len(dataset)), not the
            # number of datasets (len(datasets)) -- the original check let
            # mismatched datasets slip through.
            elif len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length."
                    "dataset {} has a different length".format(key))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        batches = {
            key: dataset[index]
            for key, dataset in self._datasets.items()
        }
        if isinstance(index, slice):
            # all per-key batches share one length; peek at any of them
            # (py2 six.next/six.itervalues shims replaced with builtins;
            # the file targets Python 3)
            length = len(next(iter(batches.values())))
            return [{key: batch[i] for key, batch in batches.items()}
                    for i in range(length)]
        else:
            return batches

    def __len__(self):
        return self._length
class SliceDataset(DatasetMixin):
    """View of ``dataset[start:finish]``, optionally remapped through an
    index permutation ``order`` (indexed over the full dataset).

    Raises:
        ValueError: if [start, finish) overruns the dataset, or order's
            length differs from the dataset's.
    """

    def __init__(self, dataset, start, finish, order=None):
        if start < 0 or finish > len(dataset):
            raise ValueError("subset overruns the dataset.")
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start
        if order is not None and len(order) != len(dataset):
            # fixed error-message typo ("euqals" -> "equal")
            raise ValueError(
                "order should have the same length as the dataset"
                "len(order) = {} which does not equal len(dataset) = {} "
                .format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        # BUGFIX: was defined as ``def len(self)``, so len(self) raised
        # TypeError and DatasetMixin's slice indexing was broken.
        return self._size

    def get_example(self, i):
        if i >= 0:
            if i >= self._size:
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError('dataset index out of range')
            index = self._finish + i
        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
class SubsetDataset(DatasetMixin):
    """View of a dataset restricted to an explicit list of indices.

    Raises:
        ValueError: if more indices are given than the dataset holds.
    """

    def __init__(self, dataset, indices):
        self._dataset = dataset
        if len(indices) > len(dataset):
            # fixed error-message grammar ("that" -> "than")
            raise ValueError("subset's size larger than dataset's size!")
        self._indices = indices
        self._size = len(indices)

    def __len__(self):
        return self._size

    def get_example(self, i):
        index = self._indices[i]
        return self._dataset[index]
class FilterDataset(DatasetMixin):
    """View of a dataset keeping only the examples accepted by filter_fn.

    The predicate is evaluated eagerly over the whole dataset at
    construction time.
    """

    def __init__(self, dataset, filter_fn):
        self._dataset = dataset
        kept = []
        for idx in range(len(dataset)):
            if filter_fn(dataset[idx]):
                kept.append(idx)
        self._indices = kept
        self._size = len(kept)

    def __len__(self):
        return self._size

    def get_example(self, i):
        return self._dataset[self._indices[i]]
class ChainDataset(DatasetMixin):
    """Concatenate several datasets: indices run through each dataset in
    the order the datasets were given.

    Raises:
        IndexError: for negative indices or an index past the total length.
    """

    def __init__(self, *datasets):
        self._datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def get_example(self, i):
        if i < 0:
            # fixed error-message typo ("doesnot" -> "does not")
            raise IndexError("ChainDataset does not support negative indexing.")
        for dataset in self._datasets:
            if i < len(dataset):
                return dataset[i]
            # walk past this dataset and keep looking in the next one
            i -= len(dataset)
        raise IndexError("dataset index out of range")
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录