PaddlePaddle / DeepSpeech
Commit b9be2bd6
Authored Aug 17, 2022 by pangchao04

add ernie-sat sampler

Parent: 83e10fad
Showing 5 changed files with 188 additions and 5 deletions (+188 -5)
paddlespeech/t2s/datasets/sampler.py                     +181  -0
paddlespeech/t2s/exps/ernie_sat/normalize.py               +1  -1
paddlespeech/t2s/exps/ernie_sat/preprocess.py              +1  -1
paddlespeech/t2s/exps/ernie_sat/train.py                   +2  -1
paddlespeech/t2s/training/updaters/standard_updater.py     +3  -2
paddlespeech/t2s/datasets/sampler.py (new file, mode 100644)
import paddle
import math
import numpy as np
from paddle.io import BatchSampler


class ErnieSATSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.
    In such a case, each process can pass an ErnieSATSampler instance
    as a DataLoader sampler, and load a subset of the original dataset that
    is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Args:
        dataset(paddle.io.Dataset): this can be an instance of `paddle.io.Dataset`
            or any other python object that implements `__len__`, so that
            BatchSampler can get the number of samples in the data source.
        batch_size(int): number of sample indices in a mini-batch.
        num_replicas(int, optional): number of processes in distributed training.
            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
            retrieved from :code:`paddle.distributed.ParallelEnv`.
            Default None.
        rank(int, optional): the rank of the current process among :attr:`num_replicas`
            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
            :code:`paddle.distributed.ParallelEnv`. Default None.
        shuffle(bool): whether to shuffle the indices order before generating
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the dataset
            size is not divisible by the batch size. Default False.
    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset, DistributedBatchSampler

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = DistributedBatchSampler(dataset, batch_size=64)

            for data in sampler:
                # do something
                break
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value"

        from paddle.fluid.dygraph.parallel import ParallelEnv

        if num_replicas is not None:
            assert isinstance(num_replicas, int) and num_replicas > 0, \
                "num_replicas should be a positive integer"
            self.nranks = num_replicas
        else:
            self.nranks = ParallelEnv().nranks

        if rank is not None:
            assert isinstance(rank, int) and rank >= 0, \
                "rank should be a non-negative integer"
            self.local_rank = rank
        else:
            self.local_rank = ParallelEnv().local_rank

        self.drop_last = drop_last
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        # pad the index list so it is evenly divisible across replicas
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(
                indices[self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        # group consecutive indices into batches, then shuffle whole batches
        batch_indices_list = []
        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                batch_indices_list.append(batch_indices)
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            batch_indices_list.append(batch_indices)

        if self.shuffle:
            np.random.RandomState(self.epoch).shuffle(batch_indices_list)
            self.epoch += 1

        for batch_indices in batch_indices_list:
            yield batch_indices

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as the seed of the random number generator. By default, users do not
        need to set this; all replicas (workers) use a different random
        ordering for each epoch. If the same number is set at each epoch,
        this sampler will yield the same ordering at all epochs.
        Arguments:
            epoch (int): Epoch number.
        Examples:
            .. code-block:: python

                import numpy as np
                from paddle.io import Dataset, DistributedBatchSampler

                # init with dataset
                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = DistributedBatchSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
        """
        self.epoch = epoch
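For context, here is a minimal usage sketch of the new sampler with `paddle.io.DataLoader`. The toy dataset and the epoch loop are illustrative only and are not part of this commit:

import numpy as np
from paddle.io import DataLoader, Dataset

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler


class ToyDataset(Dataset):
    """A stand-in dataset used only to exercise the sampler."""

    def __init__(self, num_samples=100):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        return np.random.random([80]).astype('float32')

    def __len__(self):
        return self.num_samples


dataset = ToyDataset(100)
sampler = ErnieSATSampler(dataset, batch_size=8, shuffle=True)
# a batch sampler replaces batch_size/shuffle/drop_last on the DataLoader
loader = DataLoader(dataset, batch_sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # seeds the per-epoch shuffle of batch order
    for batch in loader:
        pass  # training step goes here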
paddlespeech/t2s/exps/ernie_sat/normalize.py
...
@@ -118,7 +118,7 @@ def main():
             record["spk_emb"] = str(item["spk_emb"])
         output_metadata.append(record)
-    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata.sort(key=itemgetter('speech_lengths'))
     output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
     with jsonlines.open(output_metadata_path, 'w') as writer:
         for item in output_metadata:
...
paddlespeech/t2s/exps/ernie_sat/preprocess.py
...
@@ -165,7 +165,7 @@ def process_sentences(config,
             if record:
                 results.append(record)
-    results.sort(key=itemgetter("utt_id"))
+    results.sort(key=itemgetter("speech_lengths"))
     # replace 'w' with 'a' to write from the end of file
     with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
         for item in results:
...
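Changing the sort key from utt_id to speech_lengths appears to go hand in hand with the new sampler: ErnieSATSampler builds batches from consecutive (here, length-sorted) indices and, when shuffle=True, permutes only the order of whole batches, so each batch keeps utterances of similar length. A small self-contained check of that behaviour (the toy dataset is illustrative, not part of the recipe):

import numpy as np
from paddle.io import Dataset

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler


class Toy(Dataset):
    def __getitem__(self, idx):
        return np.float32(idx)

    def __len__(self):
        return 10


sampler = ErnieSATSampler(Toy(), batch_size=4, shuffle=True)
sampler.set_epoch(0)
batches = list(iter(sampler))
# Each inner list is a contiguous run of indices, e.g. [4, 5, 6, 7];
# only the order of the whole batches is permuted by the epoch-seeded shuffle.
print(batches)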
paddlespeech/t2s/exps/ernie_sat/train.py
...
@@ -31,6 +31,7 @@ from yacs.config import CfgNode
 from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
 from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
 from paddlespeech.t2s.models.ernie_sat import ErnieSAT
 from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
 from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
...
@@ -86,7 +87,7 @@ def train_sp(args, config):
         seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
         text_masking=config["model"]["text_masking"])
-    train_sampler = DistributedBatchSampler(
+    train_sampler = ErnieSATSampler(
         train_dataset,
         batch_size=config.batch_size,
         shuffle=True,
...
paddlespeech/t2s/training/updaters/standard_updater.py
...
@@ -27,7 +27,7 @@ from timer import timer
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updater import UpdaterBase
 from paddlespeech.t2s.training.updater import UpdaterState
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

 class StandardUpdater(UpdaterBase):
     """An example of over-simplification. Things may not be that simple, but
...
@@ -165,7 +165,8 @@ class StandardUpdater(UpdaterBase):
         # NOTE: all batch sampler for distributed training should
         # subclass DistributedBatchSampler and implement `set_epoch` method
         batch_sampler = self.dataloader.batch_sampler
-        if isinstance(batch_sampler, DistributedBatchSampler):
+        if isinstance(batch_sampler, DistributedBatchSampler) \
+                or isinstance(batch_sampler, ErnieSATSampler):
             batch_sampler.set_epoch(self.state.epoch)
         self.train_iterator = iter(self.dataloader)
...
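The extra branch is needed because ErnieSATSampler subclasses paddle.io.BatchSampler rather than DistributedBatchSampler, so the existing isinstance check alone would never call batch_sampler.set_epoch() for it. A minimal check (only the two class names come from this commit; the assertions are illustrative):

from paddle.io import BatchSampler, DistributedBatchSampler

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler

# ErnieSATSampler derives from BatchSampler, not DistributedBatchSampler,
# so isinstance(sampler, DistributedBatchSampler) alone would not match it.
assert issubclass(ErnieSATSampler, BatchSampler)
assert not issubclass(ErnieSATSampler, DistributedBatchSampler)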