Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f9a6970a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
f9a6970a
编写于
8月 22, 2022
作者:
小湉湉
提交者:
GitHub
8月 22, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2263 from oyjxer/pc
[TTS]add ernie-sat sampler
上级
c3865f2a
b9be2bd6
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
188 addition
and
5 deletion
+188
-5
paddlespeech/t2s/datasets/sampler.py
paddlespeech/t2s/datasets/sampler.py
+181
-0
paddlespeech/t2s/exps/ernie_sat/normalize.py
paddlespeech/t2s/exps/ernie_sat/normalize.py
+1
-1
paddlespeech/t2s/exps/ernie_sat/preprocess.py
paddlespeech/t2s/exps/ernie_sat/preprocess.py
+1
-1
paddlespeech/t2s/exps/ernie_sat/train.py
paddlespeech/t2s/exps/ernie_sat/train.py
+2
-1
paddlespeech/t2s/training/updaters/standard_updater.py
paddlespeech/t2s/training/updaters/standard_updater.py
+3
-2
未找到文件。
paddlespeech/t2s/datasets/sampler.py
0 → 100644
浏览文件 @
f9a6970a
import
paddle
import
math
import
numpy
as
np
from
paddle.io
import
BatchSampler
class
ErnieSATSampler
(
BatchSampler
):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset(paddle.io.Dataset): this could be a `paddle.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
num_replicas(int, optional): porcess number in distributed training.
If :attr:`num_replicas` is None, :attr:`num_replicas` will be
retrieved from :code:`paddle.distributed.ParallenEnv`.
Default None.
rank(int, optional): the rank of the current process among :attr:`num_replicas`
processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
:code:`paddle.distributed.ParallenEnv`. Default None.
shuffle(bool): whther to shuffle indices order before genrating
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler
# init with dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)
for data in sampler:
# do something
break
"""
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
shuffle
=
False
,
drop_last
=
False
):
self
.
dataset
=
dataset
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
\
"batch_size should be a positive integer"
self
.
batch_size
=
batch_size
assert
isinstance
(
shuffle
,
bool
),
\
"shuffle should be a boolean value"
self
.
shuffle
=
shuffle
assert
isinstance
(
drop_last
,
bool
),
\
"drop_last should be a boolean number"
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
if
num_replicas
is
not
None
:
assert
isinstance
(
num_replicas
,
int
)
and
num_replicas
>
0
,
\
"num_replicas should be a positive integer"
self
.
nranks
=
num_replicas
else
:
self
.
nranks
=
ParallelEnv
().
nranks
if
rank
is
not
None
:
assert
isinstance
(
rank
,
int
)
and
rank
>=
0
,
\
"rank should be a non-negative integer"
self
.
local_rank
=
rank
else
:
self
.
local_rank
=
ParallelEnv
().
local_rank
self
.
drop_last
=
drop_last
self
.
epoch
=
0
self
.
num_samples
=
int
(
math
.
ceil
(
len
(
self
.
dataset
)
*
1.0
/
self
.
nranks
))
self
.
total_size
=
self
.
num_samples
*
self
.
nranks
def
__iter__
(
self
):
num_samples
=
len
(
self
.
dataset
)
indices
=
np
.
arange
(
num_samples
).
tolist
()
indices
+=
indices
[:(
self
.
total_size
-
len
(
indices
))]
assert
len
(
indices
)
==
self
.
total_size
# subsample
def
_get_indices_by_batch_size
(
indices
):
subsampled_indices
=
[]
last_batch_size
=
self
.
total_size
%
(
self
.
batch_size
*
self
.
nranks
)
assert
last_batch_size
%
self
.
nranks
==
0
last_local_batch_size
=
last_batch_size
//
self
.
nranks
for
i
in
range
(
self
.
local_rank
*
self
.
batch_size
,
len
(
indices
)
-
last_batch_size
,
self
.
batch_size
*
self
.
nranks
):
subsampled_indices
.
extend
(
indices
[
i
:
i
+
self
.
batch_size
])
indices
=
indices
[
len
(
indices
)
-
last_batch_size
:]
subsampled_indices
.
extend
(
indices
[
self
.
local_rank
*
last_local_batch_size
:(
self
.
local_rank
+
1
)
*
last_local_batch_size
])
return
subsampled_indices
if
self
.
nranks
>
1
:
indices
=
_get_indices_by_batch_size
(
indices
)
assert
len
(
indices
)
==
self
.
num_samples
_sample_iter
=
iter
(
indices
)
batch_indices_list
=
[]
batch_indices
=
[]
for
idx
in
_sample_iter
:
batch_indices
.
append
(
idx
)
if
len
(
batch_indices
)
==
self
.
batch_size
:
batch_indices_list
.
append
(
batch_indices
)
batch_indices
=
[]
if
not
self
.
drop_last
and
len
(
batch_indices
)
>
0
:
batch_indices_list
.
append
(
batch_indices
)
if
self
.
shuffle
:
np
.
random
.
RandomState
(
self
.
epoch
).
shuffle
(
batch_indices_list
)
self
.
epoch
+=
1
for
batch_indices
in
batch_indices_list
:
yield
batch_indices
def
__len__
(
self
):
num_samples
=
self
.
num_samples
num_samples
+=
int
(
not
self
.
drop_last
)
*
(
self
.
batch_size
-
1
)
return
num_samples
//
self
.
batch_size
def
set_epoch
(
self
,
epoch
):
"""
Sets the epoch number. When :attr:`shuffle=True`, this number is used
as seeds of random numbers. By default, users may not set this, all
replicas (workers) use a different random ordering for each epoch.
If set same number at each epoch, this sampler will yield the same
ordering at all epoches.
Arguments:
epoch (int): Epoch number.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler
# init with dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)
for epoch in range(10):
sampler.set_epoch(epoch)
"""
self
.
epoch
=
epoch
paddlespeech/t2s/exps/ernie_sat/normalize.py
浏览文件 @
f9a6970a
...
...
@@ -118,7 +118,7 @@ def main():
record
[
"spk_emb"
]
=
str
(
item
[
"spk_emb"
])
output_metadata
.
append
(
record
)
output_metadata
.
sort
(
key
=
itemgetter
(
'
utt_id
'
))
output_metadata
.
sort
(
key
=
itemgetter
(
'
speech_lengths
'
))
output_metadata_path
=
Path
(
args
.
dumpdir
)
/
"metadata.jsonl"
with
jsonlines
.
open
(
output_metadata_path
,
'w'
)
as
writer
:
for
item
in
output_metadata
:
...
...
paddlespeech/t2s/exps/ernie_sat/preprocess.py
浏览文件 @
f9a6970a
...
...
@@ -165,7 +165,7 @@ def process_sentences(config,
if
record
:
results
.
append
(
record
)
results
.
sort
(
key
=
itemgetter
(
"
utt_id
"
))
results
.
sort
(
key
=
itemgetter
(
"
speech_lengths
"
))
# replace 'w' with 'a' to write from the end of file
with
jsonlines
.
open
(
output_dir
/
"metadata.jsonl"
,
'a'
)
as
writer
:
for
item
in
results
:
...
...
paddlespeech/t2s/exps/ernie_sat/train.py
浏览文件 @
f9a6970a
...
...
@@ -31,6 +31,7 @@ from yacs.config import CfgNode
from
paddlespeech.t2s.datasets.am_batch_fn
import
build_erniesat_collate_fn
from
paddlespeech.t2s.datasets.data_table
import
DataTable
from
paddlespeech.t2s.datasets.sampler
import
ErnieSATSampler
from
paddlespeech.t2s.models.ernie_sat
import
ErnieSAT
from
paddlespeech.t2s.models.ernie_sat
import
ErnieSATEvaluator
from
paddlespeech.t2s.models.ernie_sat
import
ErnieSATUpdater
...
...
@@ -86,7 +87,7 @@ def train_sp(args, config):
seg_emb
=
config
.
model
[
'enc_input_layer'
]
==
'sega_mlm'
,
text_masking
=
config
[
"model"
][
"text_masking"
])
train_sampler
=
DistributedBatch
Sampler
(
train_sampler
=
ErnieSAT
Sampler
(
train_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
,
...
...
paddlespeech/t2s/training/updaters/standard_updater.py
浏览文件 @
f9a6970a
...
...
@@ -27,7 +27,7 @@ from timer import timer
from
paddlespeech.t2s.training.reporter
import
report
from
paddlespeech.t2s.training.updater
import
UpdaterBase
from
paddlespeech.t2s.training.updater
import
UpdaterState
from
paddlespeech.t2s.datasets.sampler
import
ErnieSATSampler
class
StandardUpdater
(
UpdaterBase
):
"""An example of over-simplification. Things may not be that simple, but
...
...
@@ -165,7 +165,8 @@ class StandardUpdater(UpdaterBase):
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
batch_sampler
=
self
.
dataloader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
)
\
or
isinstance
(
batch_sampler
,
ErnieSATSampler
):
batch_sampler
.
set_epoch
(
self
.
state
.
epoch
)
self
.
train_iterator
=
iter
(
self
.
dataloader
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录