Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
c81a3f0f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c81a3f0f
编写于
12月 30, 2021
作者:
H
Hui Zhang
提交者:
GitHub
12月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[s2t] DataLoader with BatchSampler or DistributeBatchSampler (#1242)
* batchsampler or distributebatchsampler * format
上级
6d93f3e5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
48 addition
and
23 deletion
+48
-23
paddlespeech/s2t/exps/u2_st/model.py
paddlespeech/s2t/exps/u2_st/model.py
+8
-4
paddlespeech/s2t/io/converter.py
paddlespeech/s2t/io/converter.py
+16
-11
paddlespeech/s2t/io/dataloader.py
paddlespeech/s2t/io/dataloader.py
+19
-6
paddlespeech/t2s/exps/synthesize_e2e.py
paddlespeech/t2s/exps/synthesize_e2e.py
+5
-2
未找到文件。
paddlespeech/s2t/exps/u2_st/model.py
浏览文件 @
c81a3f0f
...
...
@@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
n_iter_processes
=
config
.
collator
.
num_workers
,
subsampling_factor
=
1
,
load_aux_output
=
load_transcript
,
num_encs
=
1
)
num_encs
=
1
,
dist_sampler
=
True
)
self
.
valid_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
dev_manifest
,
...
...
@@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
n_iter_processes
=
config
.
collator
.
num_workers
,
subsampling_factor
=
1
,
load_aux_output
=
load_transcript
,
num_encs
=
1
)
num_encs
=
1
,
dist_sampler
=
True
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
# test dataset, return raw text
...
...
@@ -335,7 +337,8 @@ class U2STTrainer(Trainer):
augmentation_config
,
# aug will be off when train_mode=False
n_iter_processes
=
config
.
collator
.
num_workers
,
subsampling_factor
=
1
,
num_encs
=
1
)
num_encs
=
1
,
dist_sampler
=
False
)
logger
.
info
(
"Setup test Dataloader!"
)
...
...
@@ -542,7 +545,8 @@ class U2STTester(U2STTrainer):
len_refs
+=
metrics
[
'len_refs'
]
num_ins
+=
metrics
[
'num_ins'
]
rtf
=
num_time
/
(
num_frames
*
stride_ms
)
logger
.
info
(
"RTF: %f, instance (%d), batch BELU = %f"
%
(
rtf
,
num_ins
,
bleu
))
logger
.
info
(
"RTF: %f, instance (%d), batch BELU = %f"
%
(
rtf
,
num_ins
,
bleu
))
rtf
=
num_time
/
(
num_frames
*
stride_ms
)
msg
=
"Test: "
...
...
paddlespeech/s2t/io/converter.py
浏览文件 @
c81a3f0f
...
...
@@ -65,8 +65,9 @@ class CustomConverter():
# text data (output): (text_len, )
ys_data
.
append
(
ud
)
assert
xs_data
[
0
][
0
]
is
not
None
,
"please check Reader and Augmentation impl."
assert
xs_data
[
0
][
0
]
is
not
None
,
"please check Reader and Augmentation impl."
xs_pad
,
ilens
=
[],
[]
for
xs
in
xs_data
:
# perform subsampling
...
...
@@ -79,22 +80,26 @@ class CustomConverter():
# perform padding and convert to tensor
# currently only support real number
xs_pad
.
append
(
pad_list
(
xs
,
0
).
astype
(
self
.
dtype
))
if
not
self
.
load_aux_input
:
xs_pad
,
ilens
=
xs_pad
[
0
],
ilens
[
0
]
break
# NOTE: this is for multi-output (e.g., speech translation)
ys_pad
,
olens
=
[],
[]
for
ys
in
ys_data
:
ys_pad
.
append
(
pad_list
(
[
np
.
array
(
y
[
0
][:])
if
isinstance
(
y
,
tuple
)
else
y
for
y
in
ys
],
self
.
ignore_id
))
ys_pad
.
append
(
pad_list
([
np
.
array
(
y
[
0
][:])
if
isinstance
(
y
,
tuple
)
else
y
for
y
in
ys
],
self
.
ignore_id
))
olens
.
append
(
np
.
array
([
y
[
0
].
shape
[
0
]
if
isinstance
(
y
,
tuple
)
else
y
.
shape
[
0
]
for
y
in
ys
]))
olens
.
append
(
np
.
array
(
[
y
[
0
].
shape
[
0
]
if
isinstance
(
y
,
tuple
)
else
y
.
shape
[
0
]
for
y
in
ys
]))
if
not
self
.
load_aux_output
:
ys_pad
,
olens
=
ys_pad
[
0
],
olens
[
0
]
break
...
...
paddlespeech/s2t/io/dataloader.py
浏览文件 @
c81a3f0f
...
...
@@ -18,6 +18,7 @@ from typing import Text
import
jsonlines
import
numpy
as
np
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
...
...
@@ -76,7 +77,8 @@ class BatchDataLoader():
subsampling_factor
:
int
=
1
,
load_aux_input
:
bool
=
False
,
load_aux_output
:
bool
=
False
,
num_encs
:
int
=
1
):
num_encs
:
int
=
1
,
dist_sampler
:
bool
=
False
):
self
.
json_file
=
json_file
self
.
train_mode
=
train_mode
self
.
use_sortagrad
=
sortagrad
==
-
1
or
sortagrad
>
0
...
...
@@ -94,6 +96,7 @@ class BatchDataLoader():
self
.
n_iter_processes
=
n_iter_processes
self
.
load_aux_input
=
load_aux_input
self
.
load_aux_output
=
load_aux_output
self
.
dist_sampler
=
dist_sampler
# read json data
with
jsonlines
.
open
(
json_file
,
'r'
)
as
reader
:
...
...
@@ -145,11 +148,18 @@ class BatchDataLoader():
self
.
dataset
=
TransformDataset
(
self
.
minibaches
,
self
.
converter
,
self
.
reader
)
self
.
sampler
=
DistributedBatchSampler
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
self
.
use_sortagrad
if
self
.
train_mode
else
False
,
)
if
self
.
dist_sampler
:
self
.
sampler
=
DistributedBatchSampler
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
self
.
use_sortagrad
if
self
.
train_mode
else
False
,
drop_last
=
False
,
)
else
:
self
.
sampler
=
BatchSampler
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
self
.
use_sortagrad
if
self
.
train_mode
else
False
,
drop_last
=
False
,
)
self
.
dataloader
=
DataLoader
(
dataset
=
self
.
dataset
,
...
...
@@ -181,5 +191,8 @@ class BatchDataLoader():
echo
+=
f
"subsampling_factor:
{
self
.
subsampling_factor
}
, "
echo
+=
f
"num_encs:
{
self
.
num_encs
}
, "
echo
+=
f
"num_workers:
{
self
.
n_iter_processes
}
, "
echo
+=
f
"load_aux_input:
{
self
.
load_aux_input
}
, "
echo
+=
f
"load_aux_output:
{
self
.
load_aux_output
}
, "
echo
+=
f
"dist_sampler:
{
self
.
dist_sampler
}
, "
echo
+=
f
"file:
{
self
.
json_file
}
"
return
echo
paddlespeech/t2s/exps/synthesize_e2e.py
浏览文件 @
c81a3f0f
...
...
@@ -203,12 +203,15 @@ def evaluate(args):
get_tone_ids
=
True
if
args
.
lang
==
'zh'
:
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
sentence
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
phone_ids
=
input_ids
[
"phone_ids"
]
if
get_tone_ids
:
tone_ids
=
input_ids
[
"tone_ids"
]
elif
args
.
lang
==
'en'
:
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
)
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
else
:
print
(
"lang should in {'zh', 'en'}!"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录