Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
d28ccfa9
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d28ccfa9
编写于
3月 21, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vector cli component, test=doc
上级
506d26a9
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
712 addition
and
50 deletion
+712
-50
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
+2
-2
examples/voxceleb/sv0/local/data.sh
examples/voxceleb/sv0/local/data.sh
+24
-3
examples/voxceleb/sv0/local/emb.sh
examples/voxceleb/sv0/local/emb.sh
+44
-6
examples/voxceleb/sv0/local/test.sh
examples/voxceleb/sv0/local/test.sh
+38
-4
examples/voxceleb/sv0/local/train.sh
examples/voxceleb/sv0/local/train.sh
+44
-5
examples/voxceleb/sv0/run.sh
examples/voxceleb/sv0/run.sh
+5
-7
paddleaudio/paddleaudio/metric/eer.py
paddleaudio/paddleaudio/metric/eer.py
+10
-4
paddlespeech/cli/__init__.py
paddlespeech/cli/__init__.py
+1
-0
paddlespeech/cli/vector/__init__.py
paddlespeech/cli/vector/__init__.py
+14
-0
paddlespeech/cli/vector/infer.py
paddlespeech/cli/vector/infer.py
+345
-0
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
+3
-3
paddlespeech/vector/exps/ecapa_tdnn/test.py
paddlespeech/vector/exps/ecapa_tdnn/test.py
+3
-4
paddlespeech/vector/exps/ecapa_tdnn/train.py
paddlespeech/vector/exps/ecapa_tdnn/train.py
+39
-10
paddlespeech/vector/io/batch.py
paddlespeech/vector/io/batch.py
+90
-2
paddlespeech/vector/modules/loss.py
paddlespeech/vector/modules/loss.py
+24
-0
paddlespeech/vector/modules/sid_model.py
paddlespeech/vector/modules/sid_model.py
+26
-0
未找到文件。
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
浏览文件 @
d28ccfa9
...
@@ -14,10 +14,10 @@ random_chunk: True
...
@@ -14,10 +14,10 @@ random_chunk: True
# FEATURE EXTRACTION SETTING #
# FEATURE EXTRACTION SETTING #
###########################################################
###########################################################
# currently, we only support fbank
# currently, we only support fbank
s
ample_rate
:
16000
s
r
:
16000
# sample rate
n_mels
:
80
n_mels
:
80
window_size
:
400
#25ms, sample rate 16000, 25 * 16000 / 1000 = 400
window_size
:
400
#25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_
length
:
160
#10ms, sample rate 16000, 10 * 16000 / 1000 = 160
hop_
size
:
160
#10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
###########################################################
# MODEL SETTING #
# MODEL SETTING #
...
...
examples/voxceleb/sv0/local/data.sh
浏览文件 @
d28ccfa9
#!/bin/bash
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
stage
=
-1
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage
=
0
stop_stage
=
100
stop_stage
=
100
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
if
[
$#
-ne
2
]
;
then
echo
"Usage:
$0
[options] <data-dir> <conf-path>"
;
echo
"e.g.:
$0
./data/ conf/ecapa_tdnn.yaml"
echo
"Options: "
echo
" --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo
" --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit
1
;
fi
dir
=
$1
dir
=
$1
conf_path
=
$2
conf_path
=
$2
mkdir
-p
${
dir
}
mkdir
-p
${
dir
}
if
[
${
stage
}
-le
-1
]
&&
[
${
stop_stage
}
-ge
-1
]
;
then
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# we should use the local/convert.sh convert m4a to wav
# we should use the local/convert.sh convert m4a to wav
python3
local
/data_prepare.py
\
python3
local
/data_prepare.py
\
...
...
examples/voxceleb/sv0/local/emb.sh
浏览文件 @
d28ccfa9
#!/bin/bash
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
.
./path.sh
.
./path.sh
exp_dir
=
exp/ecapa-tdnn-vox12-big//epoch_10/
# experiment directory
stage
=
0
stop_stage
=
100
exp_dir
=
exp/ecapa-tdnn-vox12-big/
# experiment directory
conf_path
=
conf/ecapa_tdnn.yaml
conf_path
=
conf/ecapa_tdnn.yaml
audio_path
=
"demo/voxceleb/00001.wav"
audio_path
=
"demo/voxceleb/00001.wav"
use_gpu
=
true
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
if
[
$#
-ne
0
]
;
then
echo
"Usage:
$0
[options]"
;
echo
"e.g.:
$0
./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo
"Options: "
echo
" --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo
" --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo
" --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
echo
" --exp-dir # experiment directorh, where is has the model.pdparams"
echo
" --conf-path # configuration file for extracting the embedding"
echo
" --audio-path # audio-path, which will be processed to extract the embedding"
exit
1
;
fi
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
# set the test device
device
=
"cpu"
if
${
use_gpu
}
;
then
device
=
"gpu"
fi
# extract the audio embedding
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
python3
${
BIN_DIR
}
/extract_emb.py
--device
"gpu"
\
# extract the audio embedding
--config
${
conf_path
}
\
python3
${
BIN_DIR
}
/extract_emb.py
--device
${
device
}
\
--audio-path
${
audio_path
}
--load-checkpoint
${
exp_dir
}
--config
${
conf_path
}
\
\ No newline at end of file
--audio-path
${
audio_path
}
--load-checkpoint
${
exp_dir
}
fi
\ No newline at end of file
examples/voxceleb/sv0/local/test.sh
浏览文件 @
d28ccfa9
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage
=
1
stop_stage
=
100
use_gpu
=
true
# if true, we run on GPU.
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
if
[
$#
-ne
3
]
;
then
echo
"Usage:
$0
[options] <data-dir> <exp-dir> <conf-path>"
;
echo
"e.g.:
$0
./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo
"Options: "
echo
" --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo
" --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo
" --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit
1
;
fi
dir
=
$1
dir
=
$1
exp_dir
=
$2
exp_dir
=
$2
conf_path
=
$3
conf_path
=
$3
python3
${
BIN_DIR
}
/test.py
\
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
--config
${
conf_path
}
\
# test the model and compute the eer metrics
--data-dir
${
dir
}
\
python3
${
BIN_DIR
}
/test.py
\
--load-checkpoint
${
exp_dir
}
--data-dir
${
dir
}
\
\ No newline at end of file
--load-checkpoint
${
exp_dir
}
\
--config
${
conf_path
}
fi
examples/voxceleb/sv0/local/train.sh
浏览文件 @
d28ccfa9
#!/bin/bash
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage
=
0
stop_stage
=
100
use_gpu
=
true
# if true, we run on GPU.
.
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
-1
;
if
[
$#
-ne
3
]
;
then
echo
"Usage:
$0
[options] <data-dir> <exp-dir> <conf-path>"
;
echo
"e.g.:
$0
./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo
"Options: "
echo
" --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo
" --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo
" --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit
1
;
fi
dir
=
$1
dir
=
$1
exp_dir
=
$2
exp_dir
=
$2
conf_path
=
$3
conf_path
=
$3
# get the gpu nums for training
ngpu
=
$(
echo
$CUDA_VISIBLE_DEVICES
|
awk
-F
","
'{print NF}'
)
ngpu
=
$(
echo
$CUDA_VISIBLE_DEVICES
|
awk
-F
","
'{print NF}'
)
echo
"using
$ngpu
gpus..."
echo
"using
$ngpu
gpus..."
# train the speaker identification task with voxceleb data
# setting training device
# Note: we will store the log file in exp/log directory
device
=
"cpu"
python3
-m
paddle.distributed.launch
--gpus
=
$CUDA_VISIBLE_DEVICES
\
if
${
use_gpu
}
;
then
${
BIN_DIR
}
/train.py
--device
"gpu"
--checkpoint-dir
${
exp_dir
}
--augment
\
device
=
"gpu"
--data-dir
${
dir
}
--config
${
conf_path
}
fi
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# train the speaker identification task with voxceleb data
# and we will create the trained model parameters in ${exp_dir}/model.pdparams as the soft link
# Note: we will store the log file in exp/log directory
python3
-m
paddle.distributed.launch
--gpus
=
$CUDA_VISIBLE_DEVICES
\
${
BIN_DIR
}
/train.py
--device
${
device
}
--checkpoint-dir
${
exp_dir
}
\
--data-dir
${
dir
}
--config
${
conf_path
}
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/voxceleb/sv0/run.sh
浏览文件 @
d28ccfa9
...
@@ -36,11 +36,10 @@ stop_stage=50
...
@@ -36,11 +36,10 @@ stop_stage=50
# data directory
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
# vox2 wav path, we must convert the m4a format to wav format
# dir=data-demo/ # data info directory
dir
=
data/
# data info directory
dir
=
demo/
# data info directory
exp_dir
=
exp/ecapa-tdnn-vox12-big/
/
# experiment directory
exp_dir
=
exp/ecapa-tdnn-vox12-big/
# experiment directory
conf_path
=
conf/ecapa_tdnn.yaml
conf_path
=
conf/ecapa_tdnn.yaml
gpus
=
0,1,2,3
gpus
=
0,1,2,3
...
@@ -50,16 +49,15 @@ mkdir -p ${exp_dir}
...
@@ -50,16 +49,15 @@ mkdir -p ${exp_dir}
if
[
$stage
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
if
[
$stage
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# and we should specifiy the vox2 data in the data.sh
bash ./local/data.sh
${
dir
}
${
conf_path
}
||
exit
-1
;
bash ./local/data.sh
${
dir
}
${
conf_path
}
||
exit
-1
;
fi
fi
if
[
$stage
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
if
[
$stage
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# stage 1: train the speaker identification model
# stage 1: train the speaker identification model
CUDA_VISIBLE_DEVICES
=
${
gpus
}
bash ./local/train.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
CUDA_VISIBLE_DEVICES
=
${
gpus
}
bash ./local/train.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
fi
fi
if
[
$stage
-le
2
]
;
then
if
[
$stage
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# stage 2: get the speaker verification scores with cosine function
# stage 2: get the speaker verification scores with cosine function
# now we only support use cosine to get the scores
# now we only support use cosine to get the scores
CUDA_VISIBLE_DEVICES
=
0 bash ./local/test.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
CUDA_VISIBLE_DEVICES
=
0 bash ./local/test.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
...
...
paddleaudio/paddleaudio/metric/eer.py
浏览文件 @
d28ccfa9
...
@@ -19,9 +19,15 @@ from sklearn.metrics import roc_curve
...
@@ -19,9 +19,15 @@ from sklearn.metrics import roc_curve
def
compute_eer
(
labels
:
np
.
ndarray
,
scores
:
np
.
ndarray
)
->
List
[
float
]:
def
compute_eer
(
labels
:
np
.
ndarray
,
scores
:
np
.
ndarray
)
->
List
[
float
]:
'''
"""Compute EER and return score threshold.
Compute EER and return score threshold.
'''
Args:
labels (np.ndarray): the trial label, shape: [N], one-dimention, N refer to the samples num
scores (np.ndarray): the trial scores, shape: [N], one-dimention, N refer to the samples num
Returns:
List[float]: eer and the specific threshold
"""
fpr
,
tpr
,
threshold
=
roc_curve
(
y_true
=
labels
,
y_score
=
scores
)
fpr
,
tpr
,
threshold
=
roc_curve
(
y_true
=
labels
,
y_score
=
scores
)
fnr
=
1
-
tpr
fnr
=
1
-
tpr
eer_threshold
=
threshold
[
np
.
nanargmin
(
np
.
absolute
((
fnr
-
fpr
)))]
eer_threshold
=
threshold
[
np
.
nanargmin
(
np
.
absolute
((
fnr
-
fpr
)))]
...
@@ -54,7 +60,7 @@ def compute_minDCF(positive_scores,
...
@@ -54,7 +60,7 @@ def compute_minDCF(positive_scores,
p_target (float, optional): Prior probability of having a target (default 0.01).
p_target (float, optional): Prior probability of having a target (default 0.01).
Returns:
Returns:
_type_: min dcf
List[float]: min dcf and the specific threshold
"""
"""
# Computing candidate thresholds
# Computing candidate thresholds
if
len
(
positive_scores
.
shape
)
>
1
:
if
len
(
positive_scores
.
shape
)
>
1
:
...
...
paddlespeech/cli/__init__.py
浏览文件 @
d28ccfa9
...
@@ -21,5 +21,6 @@ from .st import STExecutor
...
@@ -21,5 +21,6 @@ from .st import STExecutor
from
.stats
import
StatsExecutor
from
.stats
import
StatsExecutor
from
.text
import
TextExecutor
from
.text
import
TextExecutor
from
.tts
import
TTSExecutor
from
.tts
import
TTSExecutor
from
.vector
import
VectorExecutor
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
paddlespeech/cli/vector/__init__.py
0 → 100644
浏览文件 @
d28ccfa9
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.infer
import
VectorExecutor
paddlespeech/cli/vector/infer.py
0 → 100644
浏览文件 @
d28ccfa9
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
sys
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
Optional
from
typing
import
Union
import
librosa
import
numpy
as
np
import
paddle
import
soundfile
from
yacs.config
import
CfgNode
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.compliance.librosa
import
melspectrogram
from
..download
import
get_path_from_url
from
..executor
import
BaseExecutor
from
..log
import
logger
from
..utils
import
cli_register
from
..utils
import
download_and_decompress
from
..utils
import
MODEL_HOME
from
..utils
import
stats_wrapper
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.dynamic_import
import
dynamic_import
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
pretrained_models
=
{
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ecapa_tdnn-16k"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz'
,
'md5'
:
'76cb19ed857e6623856b7cd7ebbfeda4'
,
'cfg_path'
:
'model.yaml'
,
'ckpt_path'
:
'exp/conformer/checkpoints/wenetspeech'
,
},
}
model_alias
=
{
"ecapa_tdnn"
:
"paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn"
,
}
@
cli_register
(
name
=
"paddlespeech.vector"
,
description
=
"Speech to vector embedding infer command."
)
class
VectorExecutor
(
BaseExecutor
):
def
__init__
(
self
):
super
(
VectorExecutor
,
self
).
__init__
()
self
.
parser
=
argparse
.
ArgumentParser
(
prog
=
"paddlespeech.vector"
,
add_help
=
True
)
self
.
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"ecapa_tdnn-voxceleb12"
,
choices
=
[
"ecapa_tdnn"
],
help
=
"Choose model type of asr task."
)
self
.
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"spk"
,
choices
=
[
"spk"
],
help
=
"task type in vector domain"
)
self
.
parser
.
add_argument
(
"--input"
,
type
=
str
,
default
=
None
,
help
=
"Audio file to recognize."
)
self
.
parser
.
add_argument
(
"--sample_rate"
,
type
=
int
,
default
=
16000
,
choices
=
[
16000
,
8000
],
help
=
"Choose the audio sample rate of the model. 8000 or 16000"
)
self
.
parser
.
add_argument
(
"--ckpt_path"
,
type
=
str
,
default
=
None
,
help
=
"Checkpoint file of model."
)
self
.
parser
.
add_argument
(
'--config'
,
type
=
str
,
default
=
None
,
help
=
'Config of asr task. Use deault config when it is None.'
)
self
.
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
paddle
.
get_device
(),
help
=
"Choose device to execute model inference."
)
self
.
parser
.
add_argument
(
'-d'
,
'--job_dump_result'
,
action
=
'store_true'
,
help
=
'Save job result into file.'
)
self
.
parser
.
add_argument
(
'-v'
,
'--verbose'
,
action
=
'store_true'
,
help
=
'Increase logger verbosity of current task.'
)
def
execute
(
self
,
argv
:
List
[
str
])
->
bool
:
"""Command line entry for vector model
Args:
argv (List[str]): command line args list
Returns:
bool:
False: some audio occurs error
True: all audio process success
"""
# stage 0: parse the args and get the required args
parser_args
=
self
.
parser
.
parse_args
(
argv
)
model
=
parser_args
.
model
sample_rate
=
parser_args
.
sample_rate
config
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
device
=
parser_args
.
device
# stage 1: configurate the verbose flag
if
not
parser_args
.
verbose
:
self
.
disable_task_loggers
()
# stage 2: read the input data and store them as a list
task_source
=
self
.
get_task_source
(
parser_args
.
input
)
logger
.
info
(
f
"task source:
{
task_source
}
"
)
# stage 3: process the audio one by one
task_result
=
OrderedDict
()
has_exceptions
=
False
for
id_
,
input_
in
task_source
.
items
():
try
:
res
=
self
(
input_
,
model
,
sample_rate
,
config
,
ckpt_path
,
device
)
task_result
[
id_
]
=
res
except
Exception
as
e
:
has_exceptions
=
True
task_result
[
id_
]
=
f
'
{
e
.
__class__
.
__name__
}
:
{
e
}
'
logger
.
info
(
"task result as follows: "
)
logger
.
info
(
f
"
{
task_result
}
"
)
# stage 4: process the all the task results
self
.
process_task_results
(
parser_args
.
input
,
task_result
,
parser_args
.
job_dump_result
)
# stage 5: return the exception flag
# if return False, somen audio process occurs error
if
has_exceptions
:
return
False
else
:
return
True
@
stats_wrapper
def
__call__
(
self
,
audio_file
:
os
.
PathLike
,
model
:
str
=
'ecapa_tdnn-voxceleb12'
,
sample_rate
:
int
=
16000
,
config
:
os
.
PathLike
=
None
,
ckpt_path
:
os
.
PathLike
=
None
,
force_yes
:
bool
=
False
,
device
=
paddle
.
get_device
()):
audio_file
=
os
.
path
.
abspath
(
audio_file
)
if
not
self
.
_check
(
audio_file
,
sample_rate
):
sys
.
exit
(
-
1
)
logger
.
info
(
f
"device type:
{
device
}
"
)
paddle
.
device
.
set_device
(
device
)
self
.
_init_from_path
(
model
,
sample_rate
,
config
,
ckpt_path
)
self
.
preprocess
(
model
,
audio_file
)
self
.
infer
(
model
)
res
=
self
.
postprocess
()
return
res
def
_get_pretrained_path
(
self
,
tag
:
str
)
->
os
.
PathLike
:
support_models
=
list
(
pretrained_models
.
keys
())
assert
tag
in
pretrained_models
,
\
'The model "{}" you want to use has not been supported,
\
please choose other models.
\n
\
The support models includes
\n\t\t
{}'
.
format
(
tag
,
"
\n\t\t
"
.
join
(
support_models
))
res_path
=
os
.
path
.
join
(
MODEL_HOME
,
tag
)
def
_init_from_path
(
self
,
model_type
:
str
=
'ecapa_tdnn-voxceleb12'
,
sample_rate
:
int
=
16000
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
):
if
hasattr
(
self
,
"model"
):
logger
.
info
(
"Model has been initialized"
)
return
# stage 1: get the model and config path
if
cfg_path
is
None
or
ckpt_path
is
None
:
sample_rate_str
=
"16k"
if
sample_rate
==
16000
else
"8k"
tag
=
model_type
+
"-"
+
sample_rate_str
res_path
=
self
.
_get_pretrained_path
(
tag
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
+
".pdparams"
)
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
f
"start to read the ckpt from
{
self
.
ckpt_path
}
"
)
logger
.
info
(
f
"read the config from
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
"get the res path
{
self
.
res_path
}
"
)
# stage 2: read and config and init the model body
self
.
config
=
CfgNode
(
new_allowed
=
True
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
# stage 3: get the model name to instance the model network with dynamic_import
# Noet: we use the '-' to get the model name instead of '_'
logger
.
info
(
"start to dynamic import the model class"
)
model_name
=
model_type
[:
model_type
.
rindex
(
'-'
)]
logger
.
info
(
f
"model name
{
model_name
}
"
)
model_class
=
dynamic_import
(
model_name
,
model_alias
)
model_conf
=
self
.
config
.
model
backbone
=
model_class
(
**
model_conf
)
model
=
SpeakerIdetification
(
backbone
=
backbone
,
num_class
=
self
.
config
.
num_speakers
)
self
.
model
=
model
self
.
model
.
eval
()
# stage 4: load the model parameters
logger
.
info
(
"start to set the model parameters to model"
)
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
self
.
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
"create the model instance success"
)
@
paddle
.
no_grad
()
def
infer
(
self
,
model_type
:
str
):
feats
=
self
.
_inputs
[
"feats"
]
lengths
=
self
.
_inputs
[
"lengths"
]
logger
.
info
(
f
"start to do backbone network model forward"
)
logger
.
info
(
f
"feats shape:
{
feats
.
shape
}
, lengths shape:
{
lengths
.
shape
}
"
)
# embedding from (1, emb_size, 1) -> (emb_size)
embedding
=
self
.
model
.
backbone
(
feats
,
lengths
).
squeeze
().
numpy
()
logger
.
info
(
f
"embedding size:
{
embedding
.
shape
}
"
)
self
.
_outputs
[
"embedding"
]
=
embedding
def
postprocess
(
self
)
->
Union
[
str
,
os
.
PathLike
]:
return
self
.
_outputs
[
"embedding"
]
def
preprocess
(
self
,
model_type
:
str
,
input_file
:
Union
[
str
,
os
.
PathLike
]):
audio_file
=
input_file
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
logger
.
info
(
f
"Preprocess audio file:
{
audio_file
}
"
)
# stage 1: load the audio
waveform
,
sr
=
load_audio
(
audio_file
)
logger
.
info
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
# stage 2: get the audio feat
try
:
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
self
.
config
.
sr
,
n_mels
=
self
.
config
.
n_mels
,
window_size
=
self
.
config
.
window_size
,
hop_length
=
self
.
config
.
hop_size
)
logger
.
info
(
f
"extract the audio feat, shape is:
{
feat
.
shape
}
"
)
except
Exception
as
e
:
logger
.
info
(
f
"feat occurs exception
{
e
}
"
)
sys
.
exit
(
-
1
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
# in inference period, the lengths is all one without padding
lengths
=
paddle
.
ones
([
1
])
feat
=
feature_normalize
(
feat
,
mean_norm
=
True
,
std_norm
=
False
)
logger
.
info
(
f
"feats shape:
{
feat
.
shape
}
"
)
self
.
_inputs
[
"feats"
]
=
feat
self
.
_inputs
[
"lengths"
]
=
lengths
logger
.
info
(
"audio extract the feat success"
)
def
_check
(
self
,
audio_file
:
str
,
sample_rate
:
int
):
self
.
sample_rate
=
sample_rate
if
self
.
sample_rate
!=
16000
and
self
.
sample_rate
!=
8000
:
logger
.
error
(
"invalid sample rate, please input --sr 8000 or --sr 16000"
)
return
False
if
isinstance
(
audio_file
,
(
str
,
os
.
PathLike
)):
if
not
os
.
path
.
isfile
(
audio_file
):
logger
.
error
(
"Please input the right audio file path"
)
return
False
logger
.
info
(
"checking the aduio file format......"
)
try
:
audio
,
audio_sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"float32"
,
always_2d
=
True
)
except
Exception
as
e
:
logger
.
exception
(
e
)
logger
.
error
(
"can not open the audio file, please check the audio file format is 'wav'.
\n
\
you can try to use sox to change the file format.
\n
\
For example:
\n
\
sample rate: 16k
\n
\
sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav
\n
\
sample rate: 8k
\n
\
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav
\n
\
"
)
return
False
logger
.
info
(
f
"The sample rate is
{
audio_sample_rate
}
"
)
if
audio_sample_rate
!=
self
.
sample_rate
:
logger
.
error
(
"The sample rate of the input file is not {}.
\n
\
The program will resample the wav file to {}.
\n
\
If the result does not meet your expectations,
\n
\
Please input the 16k 16 bit 1 channel wav file.
\
"
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
sys
.
exit
(
-
1
)
else
:
logger
.
info
(
"The audio file format is right"
)
return
True
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
浏览文件 @
d28ccfa9
...
@@ -63,16 +63,16 @@ def extract_audio_embedding(args, config):
...
@@ -63,16 +63,16 @@ def extract_audio_embedding(args, config):
# so the final shape is [1, dim, time]
# so the final shape is [1, dim, time]
start_time
=
time
.
time
()
start_time
=
time
.
time
()
feat
=
melspectrogram
(
x
=
waveform
,
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
s
ample_rate
,
sr
=
config
.
s
r
,
n_mels
=
config
.
n_mels
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_
length
)
hop_length
=
config
.
hop_
size
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
# in inference period, the lengths is all one without padding
# in inference period, the lengths is all one without padding
lengths
=
paddle
.
ones
([
1
])
lengths
=
paddle
.
ones
([
1
])
feat
=
feature_normalize
(
feat
=
feature_normalize
(
feat
,
mean_norm
=
True
,
std_norm
=
False
,
convert_to_numpy
=
True
)
feat
,
mean_norm
=
True
,
std_norm
=
False
)
# model backbone network forward the feats and get the embedding
# model backbone network forward the feats and get the embedding
embedding
=
model
.
backbone
(
embedding
=
model
.
backbone
(
...
...
paddlespeech/vector/exps/ecapa_tdnn/test.py
浏览文件 @
d28ccfa9
...
@@ -49,8 +49,6 @@ def main(args, config):
...
@@ -49,8 +49,6 @@ def main(args, config):
# stage3: load the pre-trained model
# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
# we get the last model from the epoch and save_interval
last_save_epoch
=
(
config
.
epochs
//
config
.
save_interval
)
*
config
.
save_interval
args
.
load_checkpoint
=
os
.
path
.
join
(
args
.
load_checkpoint
,
"epoch_"
+
str
(
last_save_epoch
))
args
.
load_checkpoint
=
os
.
path
.
abspath
(
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
...
@@ -61,6 +59,7 @@ def main(args, config):
...
@@ -61,6 +59,7 @@ def main(args, config):
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
# stage4: construct the enroll and test dataloader
# stage4: construct the enroll and test dataloader
enroll_dataset
=
VoxCeleb
(
enroll_dataset
=
VoxCeleb
(
subset
=
'enroll'
,
subset
=
'enroll'
,
target_dir
=
args
.
data_dir
,
target_dir
=
args
.
data_dir
,
...
@@ -68,7 +67,7 @@ def main(args, config):
...
@@ -68,7 +67,7 @@ def main(args, config):
random_chunk
=
False
,
random_chunk
=
False
,
n_mels
=
config
.
n_mels
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_
length
)
hop_length
=
config
.
hop_
size
)
enroll_sampler
=
BatchSampler
(
enroll_sampler
=
BatchSampler
(
enroll_dataset
,
batch_size
=
config
.
batch_size
,
enroll_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
)
# Shuffle to make embedding normalization more robust.
shuffle
=
True
)
# Shuffle to make embedding normalization more robust.
...
@@ -85,7 +84,7 @@ def main(args, config):
...
@@ -85,7 +84,7 @@ def main(args, config):
random_chunk
=
False
,
random_chunk
=
False
,
n_mels
=
config
.
n_mels
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_
length
)
hop_length
=
config
.
hop_
size
)
test_sampler
=
BatchSampler
(
test_sampler
=
BatchSampler
(
test_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
)
test_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
)
...
...
paddlespeech/vector/exps/ecapa_tdnn/train.py
浏览文件 @
d28ccfa9
...
@@ -15,6 +15,7 @@ import argparse
...
@@ -15,6 +15,7 @@ import argparse
import
os
import
os
import
numpy
as
np
import
numpy
as
np
import
time
import
paddle
import
paddle
from
paddle.io
import
BatchSampler
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
...
@@ -35,6 +36,7 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
...
@@ -35,6 +36,7 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from
paddlespeech.vector.training.scheduler
import
CyclicLRScheduler
from
paddlespeech.vector.training.scheduler
import
CyclicLRScheduler
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.utils.time
import
Timer
from
paddlespeech.vector.utils.time
import
Timer
from
paddlespeech.vector.io.batch
import
batch_pad_right
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -55,7 +57,7 @@ def main(args, config):
...
@@ -55,7 +57,7 @@ def main(args, config):
train_dataset
=
VoxCeleb
(
'train'
,
target_dir
=
args
.
data_dir
)
train_dataset
=
VoxCeleb
(
'train'
,
target_dir
=
args
.
data_dir
)
dev_dataset
=
VoxCeleb
(
'dev'
,
target_dir
=
args
.
data_dir
)
dev_dataset
=
VoxCeleb
(
'dev'
,
target_dir
=
args
.
data_dir
)
if
args
.
augment
:
if
config
.
augment
:
augment_pipeline
=
build_augment_pipeline
(
target_dir
=
args
.
data_dir
)
augment_pipeline
=
build_augment_pipeline
(
target_dir
=
args
.
data_dir
)
else
:
else
:
augment_pipeline
=
[]
augment_pipeline
=
[]
...
@@ -126,6 +128,7 @@ def main(args, config):
...
@@ -126,6 +128,7 @@ def main(args, config):
# we will comment the training process
# we will comment the training process
steps_per_epoch
=
len
(
train_sampler
)
steps_per_epoch
=
len
(
train_sampler
)
timer
=
Timer
(
steps_per_epoch
*
config
.
epochs
)
timer
=
Timer
(
steps_per_epoch
*
config
.
epochs
)
last_saved_epoch
=
""
timer
.
start
()
timer
.
start
()
for
epoch
in
range
(
start_epoch
+
1
,
config
.
epochs
+
1
):
for
epoch
in
range
(
start_epoch
+
1
,
config
.
epochs
+
1
):
...
@@ -135,9 +138,19 @@ def main(args, config):
...
@@ -135,9 +138,19 @@ def main(args, config):
avg_loss
=
0
avg_loss
=
0
num_corrects
=
0
num_corrects
=
0
num_samples
=
0
num_samples
=
0
train_reader_cost
=
0.0
train_feat_cost
=
0.0
train_run_cost
=
0.0
reader_start
=
time
.
time
()
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
train_reader_cost
+=
time
.
time
()
-
reader_start
# stage 9-1: batch data is audio sample points and speaker id label
# stage 9-1: batch data is audio sample points and speaker id label
feat_start
=
time
.
time
()
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
waveforms
,
lengths
=
batch_pad_right
(
waveforms
.
numpy
())
waveforms
=
paddle
.
to_tensor
(
waveforms
)
# stage 9-2: audio sample augment method, which is done on the audio sample point
# stage 9-2: audio sample augment method, which is done on the audio sample point
# the original wavefrom and the augmented waveform is concatented in a batch
# the original wavefrom and the augmented waveform is concatented in a batch
...
@@ -153,18 +166,20 @@ def main(args, config):
...
@@ -153,18 +166,20 @@ def main(args, config):
feats
=
[]
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
for
waveform
in
waveforms
.
numpy
():
feat
=
melspectrogram
(
x
=
waveform
,
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
s
ample_rate
,
sr
=
config
.
s
r
,
n_mels
=
config
.
n_mels
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_
length
)
hop_length
=
config
.
hop_
size
)
feats
.
append
(
feat
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
# stage 9-4: feature normalize, which help converge and imporve the performance
# stage 9-4: feature normalize, which help converge and imporve the performance
feats
=
feature_normalize
(
feats
=
feature_normalize
(
feats
,
mean_norm
=
True
,
std_norm
=
False
)
# Features normalization
feats
,
mean_norm
=
True
,
std_norm
=
False
)
# Features normalization
train_feat_cost
+=
time
.
time
()
-
feat_start
# stage 9-5: model forward, such ecapa-tdnn, x-vector
# stage 9-5: model forward, such ecapa-tdnn, x-vector
train_start
=
time
.
time
()
logits
=
model
(
feats
)
logits
=
model
(
feats
)
# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
...
@@ -177,6 +192,7 @@ def main(args, config):
...
@@ -177,6 +192,7 @@ def main(args, config):
paddle
.
optimizer
.
lr
.
LRScheduler
):
paddle
.
optimizer
.
lr
.
LRScheduler
):
optimizer
.
_learning_rate
.
step
()
optimizer
.
_learning_rate
.
step
()
optimizer
.
clear_grad
()
optimizer
.
clear_grad
()
train_run_cost
+=
time
.
time
()
-
train_start
# stage 9-8: Calculate average loss per batch
# stage 9-8: Calculate average loss per batch
avg_loss
+=
loss
.
numpy
()[
0
]
avg_loss
+=
loss
.
numpy
()[
0
]
...
@@ -186,7 +202,7 @@ def main(args, config):
...
@@ -186,7 +202,7 @@ def main(args, config):
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_samples
+=
feats
.
shape
[
0
]
num_samples
+=
feats
.
shape
[
0
]
timer
.
count
()
# step plus one in timer
timer
.
count
()
# step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
# stage 9-10: print the log information only on 0-rank per log-freq batchs
if
(
batch_idx
+
1
)
%
config
.
log_interval
==
0
and
local_rank
==
0
:
if
(
batch_idx
+
1
)
%
config
.
log_interval
==
0
and
local_rank
==
0
:
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
...
@@ -197,6 +213,9 @@ def main(args, config):
...
@@ -197,6 +213,9 @@ def main(args, config):
epoch
,
config
.
epochs
,
batch_idx
+
1
,
steps_per_epoch
)
epoch
,
config
.
epochs
,
batch_idx
+
1
,
steps_per_epoch
)
print_msg
+=
' loss={:.4f}'
.
format
(
avg_loss
)
print_msg
+=
' loss={:.4f}'
.
format
(
avg_loss
)
print_msg
+=
' acc={:.4f}'
.
format
(
avg_acc
)
print_msg
+=
' acc={:.4f}'
.
format
(
avg_acc
)
print_msg
+=
' avg_reader_cost: {:.5f} sec,'
.
format
(
train_reader_cost
/
config
.
log_interval
)
print_msg
+=
' avg_feat_cost: {:.5f} sec,'
.
format
(
train_feat_cost
/
config
.
log_interval
)
print_msg
+=
' avg_train_cost: {:.5f} sec,'
.
format
(
train_run_cost
/
config
.
log_interval
)
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
lr
,
timer
.
timing
,
timer
.
eta
)
lr
,
timer
.
timing
,
timer
.
eta
)
logger
.
info
(
print_msg
)
logger
.
info
(
print_msg
)
...
@@ -204,6 +223,11 @@ def main(args, config):
...
@@ -204,6 +223,11 @@ def main(args, config):
avg_loss
=
0
avg_loss
=
0
num_corrects
=
0
num_corrects
=
0
num_samples
=
0
num_samples
=
0
train_reader_cost
=
0.0
train_feat_cost
=
0.0
train_run_cost
=
0.0
reader_start
=
time
.
time
()
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if
epoch
%
config
.
save_interval
==
0
and
batch_idx
+
1
==
steps_per_epoch
:
if
epoch
%
config
.
save_interval
==
0
and
batch_idx
+
1
==
steps_per_epoch
:
...
@@ -239,10 +263,10 @@ def main(args, config):
...
@@ -239,10 +263,10 @@ def main(args, config):
feats
=
[]
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
for
waveform
in
waveforms
.
numpy
():
feat
=
melspectrogram
(
x
=
waveform
,
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
s
ample_rate
,
sr
=
config
.
s
r
,
n_mels
=
config
.
n_mels
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_
length
)
hop_length
=
config
.
hop_
size
)
feats
.
append
(
feat
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
...
@@ -261,6 +285,7 @@ def main(args, config):
...
@@ -261,6 +285,7 @@ def main(args, config):
# stage 9-14: Save model parameters
# stage 9-14: Save model parameters
save_dir
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
save_dir
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'epoch_{}'
.
format
(
epoch
))
'epoch_{}'
.
format
(
epoch
))
last_saved_epoch
=
os
.
path
.
join
(
'epoch_{}'
.
format
(
epoch
),
"model.pdparams"
)
logger
.
info
(
'Saving model checkpoint to {}'
.
format
(
save_dir
))
logger
.
info
(
'Saving model checkpoint to {}'
.
format
(
save_dir
))
paddle
.
save
(
model
.
state_dict
(),
paddle
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
save_dir
,
'model.pdparams'
))
os
.
path
.
join
(
save_dir
,
'model.pdparams'
))
...
@@ -270,6 +295,14 @@ def main(args, config):
...
@@ -270,6 +295,14 @@ def main(args, config):
if
nranks
>
1
:
if
nranks
>
1
:
paddle
.
distributed
.
barrier
()
# Main process
paddle
.
distributed
.
barrier
()
# Main process
# stage 10: create the final trained model.pdparams with soft link
if
local_rank
==
0
:
final_model
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
"model.pdparams"
)
logger
.
info
(
f
"we will create the final model:
{
final_model
}
"
)
if
os
.
path
.
islink
(
final_model
):
logger
.
info
(
f
"An
{
final_model
}
already exists, we will rm is and create it again"
)
os
.
unlink
(
final_model
)
os
.
symlink
(
last_saved_epoch
,
final_model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# yapf: disable
# yapf: disable
...
@@ -294,10 +327,6 @@ if __name__ == "__main__":
...
@@ -294,10 +327,6 @@ if __name__ == "__main__":
type
=
str
,
type
=
str
,
default
=
'./checkpoint'
,
default
=
'./checkpoint'
,
help
=
"Directory to save model checkpoints."
)
help
=
"Directory to save model checkpoints."
)
parser
.
add_argument
(
"--augment"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Apply audio augments."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# yapf: enable
# yapf: enable
...
...
paddlespeech/vector/io/batch.py
浏览文件 @
d28ccfa9
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
numpy
def
waveform_collate_fn
(
batch
):
def
waveform_collate_fn
(
batch
):
waveforms
=
np
.
stack
([
item
[
'feat'
]
for
item
in
batch
])
waveforms
=
np
.
stack
([
item
[
'feat'
]
for
item
in
batch
])
...
@@ -80,4 +80,92 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
...
@@ -80,4 +80,92 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
# we convert the original length of each utterance to the ratio of the max length
# we convert the original length of each utterance to the ratio of the max length
lengths
=
(
lengths
/
lengths
.
max
()).
astype
(
np
.
float32
)
lengths
=
(
lengths
/
lengths
.
max
()).
astype
(
np
.
float32
)
return
{
'ids'
:
ids
,
'feats'
:
feats
,
'lengths'
:
lengths
}
return
{
'ids'
:
ids
,
'feats'
:
feats
,
'lengths'
:
lengths
}
\ No newline at end of file
def
pad_right_to
(
array
,
target_shape
,
mode
=
"constant"
,
value
=
0
):
"""
This function takes a numpy array of arbitrary shape and pads it to target
shape by appending values on the right.
Args:
array: input numpy array. Input array whose dimension we need to pad.
target_shape : (list, tuple). Target shape we want for the target array its len must be equal to array.ndim
mode : str. Pad mode, please refer to numpy.pad documentation.
value : float. Pad value, please refer to numpy.pad documentation.
Returns:
array: numpy.array. Padded array.
valid_vals : list. List containing proportion for each dimension of original, non-padded values.
"""
assert
len
(
target_shape
)
==
array
.
ndim
pads
=
[]
# this contains the abs length of the padding for each dimension.
valid_vals
=
[]
# thic contains the relative lengths for each dimension.
i
=
0
# iterating over target_shape ndims
while
i
<
len
(
target_shape
):
assert
(
target_shape
[
i
]
>=
array
.
shape
[
i
]
),
"Target shape must be >= original shape for every dim"
pads
.
append
([
0
,
target_shape
[
i
]
-
array
.
shape
[
i
]])
valid_vals
.
append
(
array
.
shape
[
i
]
/
target_shape
[
i
])
i
+=
1
array
=
numpy
.
pad
(
array
,
pads
,
mode
=
mode
,
constant_values
=
value
)
return
array
,
valid_vals
def
batch_pad_right
(
arrays
,
mode
=
"constant"
,
value
=
0
):
"""Given a list of numpy arrays it batches them together by padding to the right
on each dimension in order to get same length for all.
Args:
arrays : list. List of array we wish to pad together.
mode : str. Padding mode see numpy.pad documentation.
value : float. Padding value see numpy.pad documentation.
Returns:
array : numpy.array. Padded array.
valid_vals : list. List containing proportion for each dimension of original, non-padded values.
"""
if
not
len
(
arrays
):
raise
IndexError
(
"arrays list must not be empty"
)
if
len
(
arrays
)
==
1
:
# if there is only one array in the batch we simply unsqueeze it.
return
numpy
.
expand_dims
(
arrays
[
0
],
axis
=
0
),
numpy
.
array
([
1.0
])
if
not
(
any
(
[
arrays
[
i
].
ndim
==
arrays
[
0
].
ndim
for
i
in
range
(
1
,
len
(
arrays
))]
)
):
raise
IndexError
(
"All arrays must have same number of dimensions"
)
# FIXME we limit the support here: we allow padding of only the last dimension
# need to remove this when feat extraction is updated to handle multichannel.
max_shape
=
[]
for
dim
in
range
(
arrays
[
0
].
ndim
):
if
dim
!=
(
arrays
[
0
].
ndim
-
1
):
if
not
all
(
[
x
.
shape
[
dim
]
==
arrays
[
0
].
shape
[
dim
]
for
x
in
arrays
[
1
:]]
):
raise
EnvironmentError
(
"arrays should have same dimensions except for last one"
)
max_shape
.
append
(
max
([
x
.
shape
[
dim
]
for
x
in
arrays
]))
batched
=
[]
valid
=
[]
for
t
in
arrays
:
# for each array we apply pad_right_to
padded
,
valid_percent
=
pad_right_to
(
t
,
max_shape
,
mode
=
mode
,
value
=
value
)
batched
.
append
(
padded
)
valid
.
append
(
valid_percent
[
-
1
])
batched
=
numpy
.
stack
(
batched
)
return
batched
,
numpy
.
array
(
valid
)
paddlespeech/vector/modules/loss.py
浏览文件 @
d28ccfa9
...
@@ -11,6 +11,8 @@
...
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# This is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/nnet/losses.py
import
math
import
math
import
paddle
import
paddle
...
@@ -20,6 +22,14 @@ import paddle.nn.functional as F
...
@@ -20,6 +22,14 @@ import paddle.nn.functional as F
class
AngularMargin
(
nn
.
Layer
):
class
AngularMargin
(
nn
.
Layer
):
def
__init__
(
self
,
margin
=
0.0
,
scale
=
1.0
):
def
__init__
(
self
,
margin
=
0.0
,
scale
=
1.0
):
"""An implementation of Angular Margin (AM) proposed in the following
paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)
Args:
margin (float, optional): The margin for cosine similiarity. Defaults to 0.0.
scale (float, optional): The scale for cosine similiarity. Defaults to 1.0.
"""
super
(
AngularMargin
,
self
).
__init__
()
super
(
AngularMargin
,
self
).
__init__
()
self
.
margin
=
margin
self
.
margin
=
margin
self
.
scale
=
scale
self
.
scale
=
scale
...
@@ -31,6 +41,15 @@ class AngularMargin(nn.Layer):
...
@@ -31,6 +41,15 @@ class AngularMargin(nn.Layer):
class
AdditiveAngularMargin
(
AngularMargin
):
class
AdditiveAngularMargin
(
AngularMargin
):
def
__init__
(
self
,
margin
=
0.0
,
scale
=
1.0
,
easy_margin
=
False
):
def
__init__
(
self
,
margin
=
0.0
,
scale
=
1.0
,
easy_margin
=
False
):
"""The Implementation of Additive Angular Margin (AAM) proposed
in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
(https://arxiv.org/abs/1906.07317)
Args:
margin (float, optional): margin factor. Defaults to 0.0.
scale (float, optional): scale factor. Defaults to 1.0.
easy_margin (bool, optional): easy_margin flag. Defaults to False.
"""
super
(
AdditiveAngularMargin
,
self
).
__init__
(
margin
,
scale
)
super
(
AdditiveAngularMargin
,
self
).
__init__
(
margin
,
scale
)
self
.
easy_margin
=
easy_margin
self
.
easy_margin
=
easy_margin
...
@@ -53,6 +72,11 @@ class AdditiveAngularMargin(AngularMargin):
...
@@ -53,6 +72,11 @@ class AdditiveAngularMargin(AngularMargin):
class
LogSoftmaxWrapper
(
nn
.
Layer
):
class
LogSoftmaxWrapper
(
nn
.
Layer
):
def
__init__
(
self
,
loss_fn
):
def
__init__
(
self
,
loss_fn
):
"""Speaker identificatin loss function wrapper
including all of compositions of the loss transformation
Args:
loss_fn (_type_): the loss value of a batch
"""
super
(
LogSoftmaxWrapper
,
self
).
__init__
()
super
(
LogSoftmaxWrapper
,
self
).
__init__
()
self
.
loss_fn
=
loss_fn
self
.
loss_fn
=
loss_fn
self
.
criterion
=
paddle
.
nn
.
KLDivLoss
(
reduction
=
"sum"
)
self
.
criterion
=
paddle
.
nn
.
KLDivLoss
(
reduction
=
"sum"
)
...
...
paddlespeech/vector/modules/sid_model.py
浏览文件 @
d28ccfa9
...
@@ -24,13 +24,25 @@ class SpeakerIdetification(nn.Layer):
...
@@ -24,13 +24,25 @@ class SpeakerIdetification(nn.Layer):
lin_blocks
=
0
,
lin_blocks
=
0
,
lin_neurons
=
192
,
lin_neurons
=
192
,
dropout
=
0.1
,
):
dropout
=
0.1
,
):
"""_summary_
Args:
backbone (Paddle.nn.Layer class): the speaker identification backbone network model
num_class (_type_): the speaker class num in the training dataset
lin_blocks (int, optional): the linear layer transform between the embedding and the final linear layer. Defaults to 0.
lin_neurons (int, optional): the output dimension of final linear layer. Defaults to 192.
dropout (float, optional): the dropout factor on the embedding. Defaults to 0.1.
"""
super
(
SpeakerIdetification
,
self
).
__init__
()
super
(
SpeakerIdetification
,
self
).
__init__
()
# speaker idenfication backbone network model
# the output of the backbond network is the target embedding
self
.
backbone
=
backbone
self
.
backbone
=
backbone
if
dropout
>
0
:
if
dropout
>
0
:
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
else
:
else
:
self
.
dropout
=
None
self
.
dropout
=
None
# construct the speaker classifer
input_size
=
self
.
backbone
.
emb_size
input_size
=
self
.
backbone
.
emb_size
self
.
blocks
=
nn
.
LayerList
()
self
.
blocks
=
nn
.
LayerList
()
for
i
in
range
(
lin_blocks
):
for
i
in
range
(
lin_blocks
):
...
@@ -40,12 +52,26 @@ class SpeakerIdetification(nn.Layer):
...
@@ -40,12 +52,26 @@ class SpeakerIdetification(nn.Layer):
])
])
input_size
=
lin_neurons
input_size
=
lin_neurons
# the final layer
self
.
weight
=
paddle
.
create_parameter
(
self
.
weight
=
paddle
.
create_parameter
(
shape
=
(
input_size
,
num_class
),
shape
=
(
input_size
,
num_class
),
dtype
=
'float32'
,
dtype
=
'float32'
,
attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
XavierUniform
()),
)
attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
XavierUniform
()),
)
def
forward
(
self
,
x
,
lengths
=
None
):
def
forward
(
self
,
x
,
lengths
=
None
):
"""Do the speaker identification model forwrd,
including the speaker embedding model and the classifier model network
Args:
x (Paddle.Tensor): input audio feats,
shape=[batch, dimension, times]
lengths (_type_, optional): input audio length.
shape=[batch, times]
Defaults to None.
Returns:
_type_: _description_
"""
# x.shape: (N, C, L)
# x.shape: (N, C, L)
x
=
self
.
backbone
(
x
,
lengths
).
squeeze
(
x
=
self
.
backbone
(
x
,
lengths
).
squeeze
(
-
1
)
# (N, emb_size, 1) -> (N, emb_size)
-
1
)
# (N, emb_size, 1) -> (N, emb_size)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录