Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
506d26a9
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
506d26a9
编写于
3月 14, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change the code style to s2t code style, test=doc
上级
7eb8fa72
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
216 addition
and
208 deletion
+216
-208
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
+19
-7
examples/voxceleb/sv0/local/data.sh
examples/voxceleb/sv0/local/data.sh
+18
-0
examples/voxceleb/sv0/local/data_prepare.py
examples/voxceleb/sv0/local/data_prepare.py
+20
-19
examples/voxceleb/sv0/local/emb.sh
examples/voxceleb/sv0/local/emb.sh
+13
-0
examples/voxceleb/sv0/local/test.sh
examples/voxceleb/sv0/local/test.sh
+8
-0
examples/voxceleb/sv0/local/train.sh
examples/voxceleb/sv0/local/train.sh
+22
-0
examples/voxceleb/sv0/run.sh
examples/voxceleb/sv0/run.sh
+20
-30
paddleaudio/paddleaudio/datasets/rirs_noises.py
paddleaudio/paddleaudio/datasets/rirs_noises.py
+6
-9
paddleaudio/paddleaudio/datasets/voxceleb.py
paddleaudio/paddleaudio/datasets/voxceleb.py
+13
-17
paddleaudio/paddleaudio/utils/download.py
paddleaudio/paddleaudio/utils/download.py
+5
-3
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
+15
-9
paddlespeech/vector/exps/ecapa_tdnn/test.py
paddlespeech/vector/exps/ecapa_tdnn/test.py
+27
-31
paddlespeech/vector/exps/ecapa_tdnn/train.py
paddlespeech/vector/exps/ecapa_tdnn/train.py
+19
-6
paddlespeech/vector/io/augment.py
paddlespeech/vector/io/augment.py
+11
-5
paddlespeech/vector/utils/download.py
paddlespeech/vector/utils/download.py
+0
-72
未找到文件。
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
浏览文件 @
506d26a9
###########################################
# Data #
###########################################
batch_size
:
32
# we should explicitly specify the wav path of vox2 audio data converted from m4a
vox2_base_path
:
augment
:
True
batch_size
:
16
num_workers
:
2
num_speakers
:
7205
# 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle
:
True
...
...
@@ -11,10 +14,10 @@ random_chunk: True
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
feature
:
n_mels
:
80
window_size
:
400
#25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length
:
160
#10ms, sample rate 16000, 10 * 16000 / 1000 = 160
sample_rate
:
16000
n_mels
:
80
window_size
:
400
#25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length
:
160
#10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
# MODEL SETTING #
...
...
@@ -35,6 +38,15 @@ model:
###########################################
seed
:
1986
# following the speechbrain configuration
epochs
:
10
save_interval
:
1
0
log_interval
:
1
0
save_interval
:
1
log_interval
:
1
learning_rate
:
1e-8
###########################################
# Testing #
###########################################
global_embedding_norm
:
True
embedding_mean_norm
:
True
embedding_std_norm
:
False
examples/voxceleb/sv0/local/data.sh
0 → 100755
浏览文件 @
506d26a9
#!/bin/bash
# Prepare the VoxCeleb data info for this recipe by running local/data_prepare.py.
#
# Positional arguments (read after parse_options.sh has been sourced):
#   $1  data directory; created if missing and passed as --data-dir
#   $2  configuration file passed as --config
#
# NOTE(review): stage/stop_stage are presumably overridable on the command line
# through parse_options.sh, and MAIN_ROOT is expected in the environment - confirm.

stage=-1
stop_stage=100

# exit with a valid status (1) on failure, matching local/emb.sh
. "${MAIN_ROOT}/utils/parse_options.sh" || exit 1;

dir=$1
conf_path=$2
# expansions are quoted so that paths containing spaces or glob characters survive
mkdir -p "${dir}"

if [ "${stage}" -le -1 ] && [ "${stop_stage}" -ge -1 ]; then
    # data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
    # we should use the local/convert.sh convert m4a to wav
    python3 local/data_prepare.py \
        --data-dir "${dir}" \
        --config "${conf_path}"
fi
\ No newline at end of file
examples/voxceleb/sv0/local/data_prepare.py
浏览文件 @
506d26a9
...
...
@@ -14,10 +14,10 @@
import
argparse
import
os
import
numpy
as
np
import
paddle
from
yacs.config
import
CfgNode
from
paddleaudio.
paddleaudio.
datasets.voxceleb
import
VoxCeleb
from
paddleaudio.datasets.voxceleb
import
VoxCeleb
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.training.seeding
import
seed_everything
...
...
@@ -25,46 +25,47 @@ from paddlespeech.vector.training.seeding import seed_everything
logger
=
Log
(
__name__
).
getlog
()
def
main
(
args
):
def
main
(
args
,
config
):
# stage0: set the cpu device, all data prepare process will be done in cpu mode
paddle
.
set_device
(
"cpu"
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
seed_everything
(
config
.
seed
)
# stage 1: generate the voxceleb csv file
# Note: this may raise a C++ exception, but the program will still run fine,
# so we ignore the exception
# we explicitly pass the vox2 base path to data prepare and generate the audio info
logger
.
info
(
"start to generate the voxceleb dataset info"
)
train_dataset
=
VoxCeleb
(
'train'
,
target_dir
=
args
.
data_dir
,
vox2_base_path
=
args
.
vox2_base_path
)
dev_dataset
=
VoxCeleb
(
'dev'
,
target_dir
=
args
.
data_dir
,
vox2_base_path
=
args
.
vox2_base_path
)
'train'
,
target_dir
=
args
.
data_dir
,
vox2_base_path
=
config
.
vox2_base_path
)
# stage 2: generate the augment noise csv file
if
args
.
augment
:
if
config
.
augment
:
logger
.
info
(
"start to generate the augment dataset info"
)
augment_pipeline
=
build_augment_pipeline
(
target_dir
=
args
.
data_dir
)
if
__name__
==
"__main__"
:
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"./data/"
,
type
=
str
,
help
=
"data directory"
)
parser
.
add_argument
(
"--
vox2-base-path
"
,
parser
.
add_argument
(
"--
config
"
,
default
=
None
,
type
=
str
,
help
=
"vox2 base path, where is store the wav audio"
)
parser
.
add_argument
(
"--augment"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Apply audio augments."
)
help
=
"configuration file"
)
args
=
parser
.
parse_args
()
# yapf: enable
main
(
args
)
# https://yaml.org/type/float.html
config
=
CfgNode
(
new_allowed
=
True
)
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
freeze
()
print
(
config
)
main
(
args
,
config
)
examples/voxceleb/sv0/local/emb.sh
0 → 100755
浏览文件 @
506d26a9
#!/bin/bash
# Extract the embedding of one audio file by running ${BIN_DIR}/extract_emb.py
# with the checkpoint found in ${exp_dir}.
#
# NOTE(review): MAIN_ROOT and BIN_DIR are presumably exported by ./path.sh, and
# the defaults below presumably can be overridden through parse_options.sh - confirm.

. ./path.sh

exp_dir=exp/ecapa-tdnn-vox12-big//epoch_10/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
audio_path="demo/voxceleb/00001.wav"

source "${MAIN_ROOT}/utils/parse_options.sh" || exit 1;

# extract the audio embedding
# expansions are quoted so that paths containing spaces or glob characters survive
python3 "${BIN_DIR}/extract_emb.py" --device "gpu" \
        --config "${conf_path}" \
        --audio-path "${audio_path}" --load-checkpoint "${exp_dir}"
\ No newline at end of file
examples/voxceleb/sv0/local/test.sh
0 → 100644
浏览文件 @
506d26a9
#!/bin/bash
# Run ${BIN_DIR}/test.py on the prepared data with a trained checkpoint.
#
# Positional arguments:
#   $1  data directory, passed as --data-dir
#   $2  experiment directory, passed as --load-checkpoint
#   $3  configuration file, passed as --config
#
# NOTE(review): BIN_DIR is expected in the environment (presumably exported by
# the caller's path.sh) - confirm.

dir=$1
exp_dir=$2
conf_path=$3

# expansions are quoted so that paths containing spaces or glob characters survive
python3 "${BIN_DIR}/test.py" \
        --config "${conf_path}" \
        --data-dir "${dir}" \
        --load-checkpoint "${exp_dir}"
\ No newline at end of file
examples/voxceleb/sv0/local/train.sh
0 → 100755
浏览文件 @
506d26a9
#!/bin/bash
# Launch ${BIN_DIR}/train.py through paddle.distributed.launch on the GPUs
# listed in CUDA_VISIBLE_DEVICES.
#
# Positional arguments:
#   $1  data directory, passed as --data-dir
#   $2  experiment directory, passed as --checkpoint-dir
#   $3  configuration file, passed as --config
#
# Exit status: 1 when the training command fails, 0 otherwise.

dir=$1
exp_dir=$2
conf_path=$3

# count the comma-separated entries of CUDA_VISIBLE_DEVICES
ngpu=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

# train the speaker identification task with voxceleb data
# NOTE(review): the original note says the log file is stored in the exp/log
# directory; nothing in this script sets that location - confirm.
if ! python3 -m paddle.distributed.launch --gpus="$CUDA_VISIBLE_DEVICES" \
        "${BIN_DIR}/train.py" --device "gpu" --checkpoint-dir "${exp_dir}" --augment \
        --data-dir "${dir}" --config "${conf_path}"; then
    echo "Failed in training!"
    exit 1
fi

exit 0
\ No newline at end of file
examples/voxceleb/sv0/run.sh
浏览文件 @
506d26a9
...
...
@@ -18,7 +18,7 @@ set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md
# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md
with the script local/convert.sh
# stage 1: train the speaker identification model
# stage 2: test speaker identification
# stage 3: extract the training embedding to train the LDA and PLDA
...
...
@@ -30,49 +30,39 @@ set -e
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME=
stage
=
0
stop_stage
=
50
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
dir
=
data/
exp_dir
=
exp/ecapa-tdnn/
# experiment directory
# vox2 wav path, we must convert the m4a format to wav format
# and store them in the ${PPAUDIO_HOME}/datasets/vox2/wav/ directory
vox2_base_path
=
${
PPAUDIO_HOME
}
/datasets/vox2/wav/
mkdir
-p
${
dir
}
# dir=data-demo/ # data info directory
dir
=
demo/
# data info directory
exp_dir
=
exp/ecapa-tdnn-vox12-big//
# experiment directory
conf_path
=
conf/ecapa_tdnn.yaml
gpus
=
0,1,2,3
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
mkdir
-p
${
exp_dir
}
if
[
$stage
-le
0
]
;
then
if
[
$stage
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
python3
local
/data_prepare.py
\
--data-dir
${
dir
}
--augment
--vox2-base-path
${
vox2_base_path
}
\
--config
conf/ecapa_tdnn.yaml
# and we should specify the vox2 data in the data.sh
bash ./local/data.sh
${
dir
}
${
conf_path
}
||
exit
-1
;
fi
if
[
$stage
-le
1
]
;
then
if
[
$stage
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# stage 1: train the speaker identification model
python3
\
-m
paddle.distributed.launch
--gpus
=
0,1,2,3
\
${
BIN_DIR
}
/train.py
--device
"gpu"
--checkpoint-dir
${
exp_dir
}
--augment
\
--data-dir
${
dir
}
--config
conf/ecapa_tdnn.yaml
CUDA_VISIBLE_DEVICES
=
${
gpus
}
bash ./local/train.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
fi
if
[
$stage
-le
2
]
;
then
# stage 1: get the speaker verification scores with cosine function
python3
\
${
BIN_DIR
}
/speaker_verification_cosine.py
\
--config
conf/ecapa_tdnn.yaml
\
--data-dir
${
dir
}
--load-checkpoint
${
exp_dir
}
/epoch_10/
fi
if
[
$stage
-le
3
]
;
then
# stage 3: extract the audio embedding
python3
\
${
BIN_DIR
}
/extract_speaker_embedding.py
\
--config
conf/ecapa_tdnn.yaml
\
--audio-path
"demo/csv/00001.wav"
--load-checkpoint
${
exp_dir
}
/epoch_60/
# stage 2: get the speaker verification scores with cosine function
# now we only support use cosine to get the scores
CUDA_VISIBLE_DEVICES
=
0 bash ./local/test.sh
${
dir
}
${
exp_dir
}
${
conf_path
}
fi
# if [ $stage -le 3 ]; then
...
...
paddleaudio/paddleaudio/datasets/rirs_noises.py
浏览文件 @
506d26a9
...
...
@@ -25,13 +25,10 @@ from tqdm import tqdm
from
..backends
import
load
as
load_audio
from
..backends
import
save
as
save_wav
from
.dataset
import
feat_funcs
from
..utils
import
DATA_HOME
from
..utils
import
decompress
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.download
import
download_and_decompress
logger
=
Log
(
__name__
).
getlog
()
from
..utils.download
import
download_and_decompress
from
.dataset
import
feat_funcs
__all__
=
[
'OpenRIRNoise'
]
...
...
@@ -80,17 +77,17 @@ class OpenRIRNoise(Dataset):
def
_get_data
(
self
):
# Download audio files.
logger
.
info
(
f
"rirs noises base path:
{
self
.
base_path
}
"
)
print
(
f
"rirs noises base path:
{
self
.
base_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
base_path
):
download_and_decompress
(
self
.
archieves
,
self
.
base_path
,
decompress
=
True
)
else
:
logger
.
info
(
print
(
f
"
{
self
.
base_path
}
already exists, we will not download and decompress again"
)
# Data preparation.
logger
.
info
(
f
"prepare the csv to
{
self
.
csv_path
}
"
)
print
(
f
"prepare the csv to
{
self
.
csv_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
csv_path
):
os
.
makedirs
(
self
.
csv_path
)
self
.
prepare_data
()
...
...
@@ -161,7 +158,7 @@ class OpenRIRNoise(Dataset):
wav_files
:
List
[
str
],
output_file
:
str
,
split_chunks
:
bool
=
True
):
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
print
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"id"
,
"duration"
,
"wav"
]
infos
=
list
(
...
...
paddleaudio/paddleaudio/datasets/voxceleb.py
浏览文件 @
506d26a9
...
...
@@ -28,13 +28,8 @@ from tqdm import tqdm
from
..backends
import
load
as
load_audio
from
..utils
import
DATA_HOME
from
..utils
import
decompress
from
..utils.download
import
download_and_decompress
from
.dataset
import
feat_funcs
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.download
import
download_and_decompress
from
utils.utility
import
download
from
utils.utility
import
unpack
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
'VoxCeleb'
]
...
...
@@ -138,9 +133,9 @@ class VoxCeleb(Dataset):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
logger
.
info
(
f
"wav base path:
{
self
.
wav_path
}
"
)
print
(
f
"wav base path:
{
self
.
wav_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
wav_path
):
logger
.
info
(
f
"start to download the voxceleb1 dataset"
)
print
(
f
"start to download the voxceleb1 dataset"
)
download_and_decompress
(
# multi-zip parts concatenate to vox1_dev_wav.zip
self
.
archieves_audio_dev
,
self
.
base_path
,
...
...
@@ -152,7 +147,7 @@ class VoxCeleb(Dataset):
# Download all parts and concatenate the files into one zip file.
dev_zipfile
=
os
.
path
.
join
(
self
.
base_path
,
'vox1_dev_wav.zip'
)
logger
.
info
(
f
'Concatenating all parts to:
{
dev_zipfile
}
'
)
print
(
f
'Concatenating all parts to:
{
dev_zipfile
}
'
)
os
.
system
(
f
'cat
{
os
.
path
.
join
(
self
.
base_path
,
"vox1_dev_wav_parta*"
)
}
>
{
dev_zipfile
}
'
)
...
...
@@ -162,6 +157,7 @@ class VoxCeleb(Dataset):
# Download meta files.
if
not
os
.
path
.
isdir
(
self
.
meta_path
):
print
(
"prepare the meta data"
)
download_and_decompress
(
self
.
archieves_meta
,
self
.
meta_path
,
decompress
=
False
)
...
...
@@ -171,7 +167,7 @@ class VoxCeleb(Dataset):
self
.
prepare_data
()
data
=
[]
logger
.
info
(
print
(
f
"read the
{
self
.
subset
}
from
{
os
.
path
.
join
(
self
.
csv_path
,
f
'
{
self
.
subset
}
.
csv
')
}
"
)
with
open
(
os
.
path
.
join
(
self
.
csv_path
,
f
'
{
self
.
subset
}
.csv'
),
'r'
)
as
rf
:
...
...
@@ -266,8 +262,8 @@ class VoxCeleb(Dataset):
wav_files
:
List
[
str
],
output_file
:
str
,
split_chunks
:
bool
=
True
):
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"
id
"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"spk_id"
]
print
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"
ID
"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"spk_id"
]
# Note: this may raise a C++ exception, but the program will still run fine,
# so we can ignore the exception
with
Pool
(
cpu_count
())
as
p
:
...
...
@@ -290,7 +286,7 @@ class VoxCeleb(Dataset):
def
prepare_data
(
self
):
# Audio of speakers in veri_test_file should not be included in training set.
logger
.
info
(
"start to prepare the data csv file"
)
print
(
"start to prepare the data csv file"
)
enroll_files
=
set
()
test_files
=
set
()
# get the enroll and test audio file path
...
...
@@ -311,12 +307,12 @@ class VoxCeleb(Dataset):
# get all the train and dev audios file path
audio_files
=
[]
speakers
=
set
()
print
(
"Getting file list..."
)
for
path
in
[
self
.
wav_path
,
self
.
vox2_base_path
]:
# if vox2 directory is not set and vox2 is not a directory
# we will not process this directory
if
not
path
or
not
os
.
path
.
exists
(
path
):
logger
.
warning
(
f
"
{
path
}
is an invalid path, please check again, "
print
(
f
"
{
path
}
is an invalid path, please check again, "
"and we will ignore the vox2 base path"
)
continue
for
file
in
glob
.
glob
(
...
...
@@ -327,7 +323,7 @@ class VoxCeleb(Dataset):
speakers
.
add
(
spk
)
audio_files
.
append
(
file
)
logger
.
info
(
print
(
f
"start to generate the
{
os
.
path
.
join
(
self
.
meta_path
,
'spk_id2label.txt'
)
}
"
)
# encode the train and dev speakers label to spk_id2label.txt
...
...
paddleaudio/paddleaudio/utils/download.py
浏览文件 @
506d26a9
...
...
@@ -37,7 +37,9 @@ def decompress(file: str):
download
.
_decompress
(
file
)
def
download_and_decompress
(
archives
:
List
[
Dict
[
str
,
str
]],
path
:
str
):
def
download_and_decompress
(
archives
:
List
[
Dict
[
str
,
str
]],
path
:
str
,
decompress
:
bool
=
True
):
"""
Download archieves and decompress to specific path.
"""
...
...
@@ -47,8 +49,8 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str):
for
archive
in
archives
:
assert
'url'
in
archive
and
'md5'
in
archive
,
\
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
download
.
get_path_from_url
(
archive
[
'url'
],
path
,
archive
[
'md5'
]
)
download
.
get_path_from_url
(
archive
[
'url'
],
path
,
archive
[
'md5'
],
decompress
=
decompress
)
def
load_state_dict_from_url
(
url
:
str
,
path
:
str
,
md5
:
str
=
None
):
...
...
paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
浏览文件 @
506d26a9
...
...
@@ -14,12 +14,13 @@
import
argparse
import
os
import
time
import
numpy
as
np
import
paddle
from
yacs.config
import
CfgNode
from
paddleaudio.
paddleaudio.
backends
import
load
as
load_audio
from
paddleaudio.
paddleaudio.
compliance.librosa
import
melspectrogram
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.compliance.librosa
import
melspectrogram
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
...
...
@@ -39,7 +40,7 @@ def extract_audio_embedding(args, config):
ecapa_tdnn
=
EcapaTdnn
(
**
config
.
model
)
# stage4: build the speaker verification train instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
1211
)
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
config
.
num_speakers
)
# stage 2: load the pre-trained model
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
...
...
@@ -60,7 +61,12 @@ def extract_audio_embedding(args, config):
# feat type is numpy array, whose shape is [dim, time]
# we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
# so the final shape is [1, dim, time]
feat
=
melspectrogram
(
x
=
waveform
,
**
config
.
feature
)
start_time
=
time
.
time
()
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
sample_rate
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_length
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
# in inference period, the lengths is all one without padding
...
...
@@ -71,9 +77,13 @@ def extract_audio_embedding(args, config):
# model backbone network forward the feats and get the embedding
embedding
=
model
.
backbone
(
feat
,
lengths
).
squeeze
().
numpy
()
# (1, emb_size, 1) -> (emb_size)
elapsed_time
=
time
.
time
()
-
start_time
audio_length
=
waveform
.
shape
[
0
]
/
sr
# stage 5: do global norm with external mean and std
# todo
rtf
=
elapsed_time
/
audio_length
logger
.
info
(
f
"
{
args
.
device
}
rft=
{
rtf
}
"
)
return
embedding
...
...
@@ -92,10 +102,6 @@ if __name__ == "__main__":
type
=
str
,
default
=
''
,
help
=
"Directory to load model checkpoint to contiune trainning."
)
parser
.
add_argument
(
"--global-embedding-norm"
,
type
=
str
,
default
=
None
,
help
=
"Apply global normalization on speaker embeddings."
)
parser
.
add_argument
(
"--audio-path"
,
default
=
"./data/demo.wav"
,
type
=
str
,
...
...
paddlespeech/vector/exps/ecapa_tdnn/
speaker_verification_cosine
.py
→
paddlespeech/vector/exps/ecapa_tdnn/
test
.py
浏览文件 @
506d26a9
...
...
@@ -23,8 +23,8 @@ from paddle.io import DataLoader
from
tqdm
import
tqdm
from
yacs.config
import
CfgNode
from
paddleaudio.
paddleaudio.
datasets
import
VoxCeleb
from
paddleaudio.
paddleaudio.
metric
import
compute_eer
from
paddleaudio.datasets
import
VoxCeleb
from
paddleaudio.metric
import
compute_eer
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.batch
import
batch_feature_normalize
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
...
...
@@ -48,6 +48,9 @@ def main(args, config):
backbone
=
ecapa_tdnn
,
num_class
=
config
.
num_speakers
)
# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
last_save_epoch
=
(
config
.
epochs
//
config
.
save_interval
)
*
config
.
save_interval
args
.
load_checkpoint
=
os
.
path
.
join
(
args
.
load_checkpoint
,
"epoch_"
+
str
(
last_save_epoch
))
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
...
...
@@ -63,7 +66,9 @@ def main(args, config):
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
**
config
.
feature
)
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_length
)
enroll_sampler
=
BatchSampler
(
enroll_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
)
# Shuffle to make embedding normalization more robust.
...
...
@@ -73,13 +78,14 @@ def main(args, config):
x
,
mean_norm
=
True
,
std_norm
=
False
),
num_workers
=
config
.
num_workers
,
return_list
=
True
,)
test_dataset
=
VoxCeleb
(
subset
=
'test'
,
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
**
config
.
feature
)
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_length
)
test_sampler
=
BatchSampler
(
test_dataset
,
batch_size
=
config
.
batch_size
,
shuffle
=
True
)
...
...
@@ -89,19 +95,19 @@ def main(args, config):
x
,
mean_norm
=
True
,
std_norm
=
False
),
num_workers
=
config
.
num_workers
,
return_list
=
True
,)
# stage
6
: we must set the model to eval mode
# stage
5
: we must set the model to eval mode
model
.
eval
()
# stage
7
: global embedding norm to imporve the performance
print
(
"global embedding norm: {}"
.
format
(
args
.
global_embedding_norm
)
)
if
args
.
global_embedding_norm
:
# stage
6
: global embedding norm to imporve the performance
logger
.
info
(
f
"global embedding norm:
{
config
.
global_embedding_norm
}
"
)
if
config
.
global_embedding_norm
:
global_embedding_mean
=
None
global_embedding_std
=
None
mean_norm_flag
=
args
.
embedding_mean_norm
std_norm_flag
=
args
.
embedding_std_norm
mean_norm_flag
=
config
.
embedding_mean_norm
std_norm_flag
=
config
.
embedding_std_norm
batch_count
=
0
# stage
8
: Compute embeddings of audios in enrol and test dataset from model.
# stage
7
: Compute embeddings of audios in enrol and test dataset from model.
id2embedding
=
{}
# Run multi times to make embedding normalization more stable.
for
i
in
range
(
2
):
...
...
@@ -121,7 +127,7 @@ def main(args, config):
# Global embedding normalization.
# if we use the global embedding norm
# EER can be reduced by about 10% relative
if
args
.
global_embedding_norm
:
if
config
.
global_embedding_norm
:
batch_count
+=
1
current_mean
=
embeddings
.
mean
(
axis
=
0
)
if
mean_norm_flag
else
0
...
...
@@ -145,21 +151,22 @@ def main(args, config):
# Update embedding dict.
id2embedding
.
update
(
dict
(
zip
(
ids
,
embeddings
)))
# stage
9
: Compute cosine scores.
# stage
8
: Compute cosine scores.
labels
=
[]
enrol_ids
=
[]
enrol
l
_ids
=
[]
test_ids
=
[]
logger
.
info
(
f
"read the trial from
{
VoxCeleb
.
veri_test_file
}
"
)
with
open
(
VoxCeleb
.
veri_test_file
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
label
,
enrol_id
,
test_id
=
line
.
strip
().
split
(
' '
)
label
,
enrol
l
_id
,
test_id
=
line
.
strip
().
split
(
' '
)
labels
.
append
(
int
(
label
))
enrol
_ids
.
append
(
enrol_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'-
-'
))
test_ids
.
append
(
test_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'-
-
'
))
enrol
l_ids
.
append
(
enroll_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'
-'
))
test_ids
.
append
(
test_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'-'
))
cos_sim_func
=
paddle
.
nn
.
CosineSimilarity
(
axis
=
1
)
enrol_embeddings
,
test_embeddings
=
map
(
lambda
ids
:
paddle
.
to_tensor
(
np
.
asarray
([
id2embedding
[
id
]
for
id
in
ids
],
dtype
=
'float32'
)),
[
enrol_ids
,
test_ids
np
.
asarray
([
id2embedding
[
uttid
]
for
utt
id
in
ids
],
dtype
=
'float32'
)),
[
enrol
l
_ids
,
test_ids
])
# (N, emb_size)
scores
=
cos_sim_func
(
enrol_embeddings
,
test_embeddings
)
EER
,
threshold
=
compute_eer
(
np
.
asarray
(
labels
),
scores
.
numpy
())
...
...
@@ -187,17 +194,6 @@ if __name__ == "__main__":
type
=
str
,
default
=
''
,
help
=
"Directory to load model checkpoint to contiune trainning."
)
parser
.
add_argument
(
"--global-embedding-norm"
,
default
=
False
,
action
=
"store_true"
,
help
=
"Apply global normalization on speaker embeddings."
)
parser
.
add_argument
(
"--embedding-mean-norm"
,
default
=
True
,
help
=
"Apply mean normalization on speaker embeddings."
)
parser
.
add_argument
(
"--embedding-std-norm"
,
type
=
bool
,
default
=
False
,
help
=
"Apply std normalization on speaker embeddings."
)
args
=
parser
.
parse_args
()
# yapf: enable
# https://yaml.org/type/float.html
...
...
paddlespeech/vector/exps/ecapa_tdnn/train.py
浏览文件 @
506d26a9
...
...
@@ -21,8 +21,8 @@ from paddle.io import DataLoader
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
from
paddleaudio.
paddleaudio.
compliance.librosa
import
melspectrogram
from
paddleaudio.
paddleaudio.
datasets.voxceleb
import
VoxCeleb
from
paddleaudio.compliance.librosa
import
melspectrogram
from
paddleaudio.datasets.voxceleb
import
VoxCeleb
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.io.augment
import
waveform_augment
...
...
@@ -68,6 +68,8 @@ def main(args, config):
backbone
=
ecapa_tdnn
,
num_class
=
VoxCeleb
.
num_speakers
)
# stage5: build the optimizer, we now only construct the AdamW optimizer
# 140000 is single gpu steps
# so, in multi-gpu mode, we reduce the step_size to 140000//nranks to enable CyclicLRScheduler
lr_schedule
=
CyclicLRScheduler
(
base_lr
=
config
.
learning_rate
,
max_lr
=
1e-3
,
step_size
=
140000
//
nranks
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
...
...
@@ -138,6 +140,10 @@ def main(args, config):
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
# stage 9-2: audio sample augment method, which is done on the audio sample point
# the original waveform and the augmented waveforms are concatenated in a batch
# eg. five augment method in the augment pipeline
# the final data nums is batch_size * [five + one]
# -> five augmented waveform batch plus one original batch waveform
if
len
(
augment_pipeline
)
!=
0
:
waveforms
=
waveform_augment
(
waveforms
,
augment_pipeline
)
labels
=
paddle
.
concat
(
...
...
@@ -146,7 +152,11 @@ def main(args, config):
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
feat
=
melspectrogram
(
x
=
waveform
,
**
config
.
feature
)
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
sample_rate
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_length
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
...
...
@@ -205,7 +215,7 @@ def main(args, config):
# stage 9-12: construct the valid dataset dataloader
dev_sampler
=
BatchSampler
(
dev_dataset
,
batch_size
=
config
.
batch_size
//
4
,
batch_size
=
config
.
batch_size
,
shuffle
=
False
,
drop_last
=
False
)
dev_loader
=
DataLoader
(
...
...
@@ -228,8 +238,11 @@ def main(args, config):
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
# feat = melspectrogram(x=waveform, **cpu_feat_conf)
feat
=
melspectrogram
(
x
=
waveform
,
**
config
.
feature
)
feat
=
melspectrogram
(
x
=
waveform
,
sr
=
config
.
sample_rate
,
n_mels
=
config
.
n_mels
,
window_size
=
config
.
window_size
,
hop_length
=
config
.
hop_length
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
...
...
paddlespeech/vector/io/augment.py
浏览文件 @
506d26a9
...
...
@@ -22,8 +22,8 @@ import paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddleaudio
.paddleaudio
import
load
as
load_audio
from
paddleaudio.
paddleaudio.
datasets.rirs_noises
import
OpenRIRNoise
from
paddleaudio
import
load
as
load_audio
from
paddleaudio.datasets.rirs_noises
import
OpenRIRNoise
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
convolve1d
...
...
@@ -879,14 +879,18 @@ def waveform_augment(waveforms: paddle.Tensor,
"""process the augment pipeline and return all the waveforms
Args:
waveforms (paddle.Tensor):
_description_
augment_pipeline (List[paddle.nn.Layer]):
_description_
waveforms (paddle.Tensor):
original batch waveform
augment_pipeline (List[paddle.nn.Layer]):
agument pipeline process
Returns:
paddle.Tensor:
_description_
paddle.Tensor:
all the audio waveform including the original waveform and augmented waveform
"""
# stage 0: store the original waveforms
waveforms_aug_list
=
[
waveforms
]
# augment the original batch waveform
for
aug
in
augment_pipeline
:
# stage 1: augment the data
waveforms_aug
=
aug
(
waveforms
)
# (N, L)
if
waveforms_aug
.
shape
[
1
]
>=
waveforms
.
shape
[
1
]:
# Trunc
...
...
@@ -897,6 +901,8 @@ def waveform_augment(waveforms: paddle.Tensor,
waveforms_aug
=
F
.
pad
(
waveforms_aug
.
unsqueeze
(
-
1
),
[
0
,
lengths_to_pad
],
data_format
=
'NLC'
).
squeeze
(
-
1
)
# stage 2: append the augmented waveform into the list
waveforms_aug_list
.
append
(
waveforms_aug
)
# get the all the waveforms
return
paddle
.
concat
(
waveforms_aug_list
,
axis
=
0
)
paddlespeech/vector/utils/download.py
已删除
100644 → 0
浏览文件 @
7eb8fa72
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
Dict
from
typing
import
List
from
paddle.framework
import
load
as
load_state_dict
from
paddle.utils
import
download
__all__
=
[
'decompress'
,
'download_and_decompress'
,
'load_state_dict_from_url'
,
]
def decompress(file: str, path: str = None):
    """
    Extracts all files from a compressed file to specific path.

    Args:
        file (str): Path of the archive to extract. It must be an existing file.
        path (str, optional): Directory to extract into; it is created when it
            does not exist. When None (the default), `download._decompress`
            is called on `file` directly.

    Raises:
        AssertionError: If `file` is not an existing file.
    """
    # The default used to be the class `os.PathLike`, which made the
    # `path is None` branch unreachable for callers that omit `path` and made
    # `os.path.isdir(path)` raise TypeError. None is the intended sentinel.
    assert os.path.isfile(file), "File: {} not exists.".format(file)

    if path is None:
        print("decompress the data: {}".format(file))
        download._decompress(file)
        return

    print("decompress the data: {} to {}".format(file, path))
    if not os.path.isdir(path):
        os.makedirs(path)

    # `download._decompress` only receives the archive path, so the archive is
    # moved into `path` for the extraction
    # (presumably it extracts next to the archive - TODO confirm).
    tmp_file = os.path.join(path, os.path.basename(file))
    os.rename(file, tmp_file)
    try:
        download._decompress(tmp_file)
    finally:
        # always move the archive back, even when the extraction fails
        os.rename(tmp_file, file)
def download_and_decompress(archives: List[Dict[str, str]],
                            path: str,
                            decompress: bool=True):
    """
    Download archieves and decompress to specific path.

    Args:
        archives (List[Dict[str, str]]): One dict per archive; each must
            contain the keys "url" and "md5".
        path (str): Target directory; it is created when it does not exist.
        decompress (bool, optional): Forwarded to `download.get_path_from_url`
            as its `decompress` argument. Defaults to True. Inside this
            function the name shadows the module-level `decompress` function.

    Raises:
        AssertionError: If an archive dict lacks the "url" or "md5" key.
    """
    if not os.path.isdir(path):
        os.makedirs(path)

    for archive in archives:
        # The message must be an f-string that refers to `archive`; the former
        # plain string printed the placeholder text literally and named the
        # undefined `archieve`.
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
        download.get_path_from_url(
            archive['url'], path, archive['md5'], decompress=decompress)
def load_state_dict_from_url(url: str, path: str, md5: str = None):
    """
    Download and load a state dict from url

    Args:
        url (str): Location of the file to download.
        path (str): Directory that receives the download; it is created when
            it does not exist.
        md5 (str, optional): Passed through to `download.get_path_from_url`.

    Returns:
        The result of `load_state_dict` applied to the file named after the
        last component of `url` inside `path`.
    """
    directory_missing = not os.path.isdir(path)
    if directory_missing:
        os.makedirs(path)

    download.get_path_from_url(url, path, md5)

    # the downloaded file is looked up under the basename of the url
    local_file = os.path.join(path, os.path.basename(url))
    return load_state_dict(local_file)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录