PaddlePaddle / DeepSpeech
Commit 7668f614
Authored Mar 03, 2022 by xiongxinlei
add sid dataloader for training, test=doc
Parent: 6af2bc3d
Showing 2 changed files with 50 additions and 7 deletions
examples/voxceleb/sv0/local/train.py    +30 -7
paddlespeech/vector/datasets/batch.py   +20 -0
examples/voxceleb/sv0/local/train.py @ 7668f614
...
@@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import os
import paddle
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.datasets.batch import waveform_collate_fn
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification
from paddlespeech.vector.layers.loss import AdditiveAngularMargin, LogSoftmaxWrapper
def main(args):
    # stage0: set the training device, cpu or gpu
...
@@ -61,7 +66,6 @@ def main(args):
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
    # stage7: confirm training start epoch
    # if pre-trained model exists, start epoch confirmed by the pre-trained model
    start_epoch = 0
...
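The criterion in the hunk above combines an additive angular margin with `margin=0.2` and `scale=30`. As a rough illustration of what those two numbers do, here is a minimal sketch of the textbook form of the loss in plain Python. It is not PaddleSpeech's `AdditiveAngularMargin` or `LogSoftmaxWrapper` implementation, and the helper names `aam_logits` and `log_softmax` are hypothetical.

```python
import math


def aam_logits(cosines, target, margin=0.2, scale=30.0):
    """Textbook additive angular margin (hypothetical helper).

    The target-class cosine cos(theta) becomes cos(theta + margin);
    every logit is then multiplied by `scale`.
    """
    out = []
    for k, c in enumerate(cosines):
        if k == target:
            c = max(-1.0, min(1.0, c))  # guard acos against rounding error
            c = math.cos(math.acos(c) + margin)
        out.append(scale * c)
    return out


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


# Cosine similarities between one embedding and three speaker prototypes.
cosines = [0.8, 0.1, -0.3]
target = 0

loss_plain = -log_softmax([30.0 * c for c in cosines])[target]
loss_margin = -log_softmax(aam_logits(cosines, target))[target]
print(loss_plain, loss_margin)
```

The margin lowers the target-class logit, so the same embedding incurs a larger loss than it would without the margin, which pushes training toward tighter speaker clusters.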
@@ -89,7 +93,19 @@ def main(args):
            print(f'Restore training from epoch {start_epoch}.')
        except ValueError:
            pass
    # stage8: we build the batch sampler for paddle.DataLoader
    train_sampler = DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    train_loader = DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        collate_fn=waveform_collate_fn,
        return_list=True,
        use_buffer_reader=True, )
if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
...
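The sampler in stage8 is configured with `shuffle=True` and `drop_last=False`. The sketch below illustrates, in plain Python, what those two settings mean for batch composition in a single process. It is a simplified illustration only: it does not model how `DistributedBatchSampler` splits indices across ranks, and `make_batches` is a hypothetical helper, not a Paddle API.

```python
import random


def make_batches(num_samples, batch_size, shuffle=True, drop_last=False, seed=0):
    """Group sample indices into batches (hypothetical single-process helper)."""
    indices = list(range(num_samples))
    if shuffle:
        # A seeded generator keeps the illustration reproducible.
        random.Random(seed).shuffle(indices)
    batches = [
        indices[i:i + batch_size] for i in range(0, num_samples, batch_size)
    ]
    # With drop_last=True, a short final batch is discarded.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


batches = make_batches(num_samples=10, batch_size=4)
print([len(b) for b in batches])
```

With `drop_last=False`, as in the commit, the final short batch is kept, so every sample is seen once per epoch at the cost of one smaller batch.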
@@ -105,10 +121,17 @@ if __name__ == "__main__":
                        type=float,
                        default=1e-8,
                        help="Learning rate used to train with warmup.")
    parser.add_argument("--load_checkpoint", type=str, default=None,
    parser.add_argument("--load_checkpoint", type=str, default=None,
                        help="Directory to load model checkpoint to contiune trainning.")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Total examples' number in batch for training.")
    parser.add_argument("--num_workers", type=int, default=0,
                        help="Number of workers in dataloader.")
    args = parser.parse_args()
    # yapf: enable
...
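The hunk above adds `--batch_size` and `--num_workers` next to the existing `--load_checkpoint` flag. A minimal standalone sketch of how these three flags parse is shown here; the `build_parser` wrapper and the description string are assumptions made for illustration, while the flag names, types, and defaults are copied from the diff.

```python
import argparse


def build_parser():
    """Rebuild only the three flags visible in the hunk (hypothetical wrapper)."""
    parser = argparse.ArgumentParser(description="Flag-parsing sketch")
    parser.add_argument("--load_checkpoint", type=str, default=None,
                        help="Directory to load a model checkpoint from.")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Number of examples per training batch.")
    parser.add_argument("--num_workers", type=int, default=0,
                        help="Number of workers in the dataloader.")
    return parser


# No flags given: the defaults from the diff apply.
defaults = build_parser().parse_args([])

# Overriding the two new flags on the command line.
custom = build_parser().parse_args(["--batch_size", "32", "--num_workers", "4"])
print(defaults, custom)
```

Values arrive already converted to `int`, so `args.batch_size` and `args.num_workers` can be passed directly to the sampler and loader.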
paddlespeech/vector/datasets/batch.py (new file, mode 0 → 100644) @ 7668f614
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def waveform_collate_fn(batch):
    waveforms = np.stack([item['feat'] for item in batch])
    labels = np.stack([item['label'] for item in batch])
    return {'waveforms': waveforms, 'labels': labels}
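The collate function above stacks per-sample `'feat'` and `'label'` entries into batch arrays. The sketch below exercises it on two synthetic items. Note that the file as shown uses `np` without a visible import, so the `import numpy as np` line is an assumption about the intended dependency. The waveform length of 16000 samples and the label values are made up for illustration.

```python
import numpy as np  # assumed: the diff as shown has no import for `np`


def waveform_collate_fn(batch):
    # Stack equal-length waveforms into shape (batch, num_samples).
    waveforms = np.stack([item['feat'] for item in batch])
    # Stack scalar labels into shape (batch,).
    labels = np.stack([item['label'] for item in batch])
    return {'waveforms': waveforms, 'labels': labels}


# Two synthetic items with equal-length waveforms (hypothetical data).
batch = [
    {'feat': np.zeros(16000, dtype=np.float32), 'label': 3},
    {'feat': np.ones(16000, dtype=np.float32), 'label': 7},
]
out = waveform_collate_fn(batch)
print(out['waveforms'].shape, out['labels'].shape)
```

Because `np.stack` requires every array to have the same shape, this collate function only works when all waveforms in a batch have the same length; variable-length clips would need cropping or padding before they reach it.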