Commit 506d26a9 authored by xiongxinlei

change the code style to s2t code style, test=doc

Parent 7eb8fa72
###########################################
#                   Data                  #
###########################################
-batch_size: 32
+# we should explicitly specify the wav path of vox2 audio data converted from m4a
+vox2_base_path:
+augment: True
+batch_size: 16
num_workers: 2
num_speakers: 7205  # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
@@ -11,10 +14,10 @@ random_chunk: True
###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
# currently, we only support fbank
-feature:
+sample_rate: 16000
n_mels: 80
window_size: 400  # 25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length: 160  # 10ms, sample rate 16000, 10 * 16000 / 1000 = 160
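The two frame parameters above are durations expressed in samples; a minimal Python sketch of the ms-to-samples arithmetic in these comments (values taken from this config):

# convert frame durations from milliseconds to samples
sample_rate = 16000
window_ms, hop_ms = 25, 10
window_size = window_ms * sample_rate // 1000  # 25 * 16000 / 1000 = 400
hop_length = hop_ms * sample_rate // 1000      # 10 * 16000 / 1000 = 160
assert (window_size, hop_length) == (400, 160)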
###########################################################
#                      MODEL SETTING                      #
###########################################################
@@ -35,6 +38,15 @@ model:
###########################################
seed: 1986  # following the speechbrain configuration
epochs: 10
-save_interval: 10
+save_interval: 1
-log_interval: 10
+log_interval: 1
learning_rate: 1e-8
+###########################################
+#                  Testing                #
+###########################################
+global_embedding_norm: True
+embedding_mean_norm: True
+embedding_std_norm: False
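A quick sketch of what these three testing flags control downstream (hypothetical NumPy array; the real loop in the test script below operates on paddle tensors): with embedding_mean_norm: True and embedding_std_norm: False, every embedding has the global mean subtracted but is not divided by the global std.

import numpy as np

# hypothetical batch of speaker embeddings, shape (N, emb_size)
embeddings = np.random.randn(8, 192).astype("float32")

embedding_mean_norm, embedding_std_norm = True, False
mean = embeddings.mean(axis=0) if embedding_mean_norm else 0.0
std = embeddings.std(axis=0) if embedding_std_norm else 1.0
normalized = (embeddings - mean) / std  # mean-normalized only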
#!/bin/bash
stage=-1
stop_stage=100
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
dir=$1
conf_path=$2
mkdir -p ${dir}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# we should use local/convert.sh to convert m4a to wav
python3 local/data_prepare.py \
--data-dir ${dir} \
--config ${conf_path}
fi
\ No newline at end of file
@@ -14,10 +14,10 @@
import argparse
import os
-import numpy as np

import paddle
+from yacs.config import CfgNode

-from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
+from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything
@@ -25,46 +25,47 @@ from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
-def main(args):
+def main(args, config):
    # stage 0: set the cpu device, all data prepare processes will be done in cpu mode
    paddle.set_device("cpu")

    # set the random seed, it is a must for multiprocess training
-    seed_everything(args.seed)
+    seed_everything(config.seed)

    # stage 1: generate the voxceleb csv file
    # Note: this may raise a C++ exception, but the program will execute fine,
    # so we ignore the exception
    # we explicitly pass the vox2 base path to data prepare and generate the audio info
+    logger.info("start to generate the voxceleb dataset info")
    train_dataset = VoxCeleb(
-        'train', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
+        'train', target_dir=args.data_dir, vox2_base_path=config.vox2_base_path)
    dev_dataset = VoxCeleb(
        'dev', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)

    # stage 2: generate the augment noise csv file
-    if args.augment:
+    if config.augment:
+        logger.info("start to generate the augment dataset info")
        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
-    parser.add_argument("--seed",
-                        default=0,
-                        type=int,
-                        help="random seed for paddle, numpy and python random package")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
-    parser.add_argument("--vox2-base-path",
+    parser.add_argument("--config",
                        default=None,
                        type=str,
-                        help="vox2 base path, where the wav audio is stored")
+                        help="configuration file")
-    parser.add_argument("--augment",
-                        action="store_true",
-                        default=False,
-                        help="Apply audio augments.")
    args = parser.parse_args()
    # yapf: enable

-    main(args)
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+    config.freeze()
+    print(config)
+
+    main(args, config)
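The yaml.org link above points at a YAML 1.1 quirk worth knowing: a scientific-notation float without a dot, such as the learning_rate: 1e-8 in this config, is loaded as a string by PyYAML (which yacs uses underneath), while 1.0e-8 loads as a float. A standalone illustration, independent of this repo:

import yaml

# YAML 1.1 requires a dot in the mantissa for floats in scientific notation
print(type(yaml.safe_load("lr: 1e-8")["lr"]))    # <class 'str'>
print(type(yaml.safe_load("lr: 1.0e-8")["lr"]))  # <class 'float'>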
#!/bin/bash
. ./path.sh
exp_dir=exp/ecapa-tdnn-vox12-big//epoch_10/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
audio_path="demo/voxceleb/00001.wav"
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# extract the audio embedding
python3 ${BIN_DIR}/extract_emb.py --device "gpu" \
--config ${conf_path} \
--audio-path ${audio_path} --load-checkpoint ${exp_dir}
\ No newline at end of file
dir=$1
exp_dir=$2
conf_path=$3
python3 ${BIN_DIR}/test.py \
--config ${conf_path} \
--data-dir ${dir} \
--load-checkpoint ${exp_dir}
\ No newline at end of file
#!/bin/bash
dir=$1
exp_dir=$2
conf_path=$3
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
# train the speaker identification task with voxceleb data
# Note: we will store the log file in exp/log directory
python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \
${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
--data-dir ${dir} --config ${conf_path}
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
\ No newline at end of file
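For reference, the awk one-liner above counts comma-separated device ids; a rough Python equivalent of that GPU count, assuming CUDA_VISIBLE_DEVICES is set the same way:

import os

# count comma-separated device ids, e.g. "0,1,2,3" -> 4
devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
ngpu = len(devices.split(",")) if devices else 0
print(f"using {ngpu} gpus...")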
@@ -18,7 +18,7 @@ set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
-# voxceleb2 data is m4a format, so we need users to convert the m4a to wav themselves as described in Readme.md
+# voxceleb2 data is m4a format, so we need users to convert the m4a to wav themselves as described in Readme.md with the script local/convert.sh
# stage 1: train the speaker identification model
# stage 2: test speaker identification
# stage 3: extract the training embedding to train the LDA and PLDA
@@ -30,49 +30,39 @@ set -e
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME=
stage=0
+stop_stage=50

# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
-dir=data/
-exp_dir=exp/ecapa-tdnn/ # experiment directory

# vox2 wav path, we must convert the m4a format to wav format
-# and store them in the ${PPAUDIO_HOME}/datasets/vox2/wav/ directory
-vox2_base_path=${PPAUDIO_HOME}/datasets/vox2/wav/
+# dir=data-demo/ # data info directory
+dir=demo/ # data info directory
+mkdir -p ${dir}
+
+exp_dir=exp/ecapa-tdnn-vox12-big// # experiment directory
+conf_path=conf/ecapa_tdnn.yaml
+gpus=0,1,2,3
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

mkdir -p ${exp_dir}

-if [ $stage -le 0 ]; then
+if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
-    python3 local/data_prepare.py \
-        --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path} \
-        --config conf/ecapa_tdnn.yaml
+    # and we should specify the vox2 data in data.sh
+    bash ./local/data.sh ${dir} ${conf_path} || exit -1;
fi

-if [ $stage -le 1 ]; then
+if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # stage 1: train the speaker identification model
-    python3 \
-    -m paddle.distributed.launch --gpus=0,1,2,3 \
-    ${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
-    --data-dir ${dir} --config conf/ecapa_tdnn.yaml
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
fi

if [ $stage -le 2 ]; then
-    # stage 1: get the speaker verification scores with cosine function
-    python3 \
-    ${BIN_DIR}/speaker_verification_cosine.py \
-    --config conf/ecapa_tdnn.yaml \
-    --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
-fi
-
-if [ $stage -le 3 ]; then
-    # stage 3: extract the audio embedding
-    python3 \
-    ${BIN_DIR}/extract_speaker_embedding.py \
-    --config conf/ecapa_tdnn.yaml \
-    --audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/
+    # stage 2: get the speaker verification scores with cosine function
+    # now we only support using cosine to get the scores
+    CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
fi

# if [ $stage -le 3 ]; then
...
@@ -25,13 +25,10 @@ from tqdm import tqdm
from ..backends import load as load_audio
from ..backends import save as save_wav
-from .dataset import feat_funcs
from ..utils import DATA_HOME
from ..utils import decompress
-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.utils.download import download_and_decompress
+from ..utils.download import download_and_decompress
+from .dataset import feat_funcs
-
-logger = Log(__name__).getlog()

__all__ = ['OpenRIRNoise']
@@ -80,17 +77,17 @@ class OpenRIRNoise(Dataset):
    def _get_data(self):
        # Download audio files.
-        logger.info(f"rirs noises base path: {self.base_path}")
+        print(f"rirs noises base path: {self.base_path}")
        if not os.path.isdir(self.base_path):
            download_and_decompress(
                self.archieves, self.base_path, decompress=True)
        else:
-            logger.info(
+            print(
                f"{self.base_path} already exists, we will not download and decompress again"
            )

        # Data preparation.
-        logger.info(f"prepare the csv to {self.csv_path}")
+        print(f"prepare the csv to {self.csv_path}")
        if not os.path.isdir(self.csv_path):
            os.makedirs(self.csv_path)
        self.prepare_data()
@@ -161,7 +158,7 @@ class OpenRIRNoise(Dataset):
                     wav_files: List[str],
                     output_file: str,
                     split_chunks: bool=True):
-        logger.info(f'Generating csv: {output_file}')
+        print(f'Generating csv: {output_file}')
        header = ["id", "duration", "wav"]
        infos = list(
...
@@ -28,13 +28,8 @@ from tqdm import tqdm
from ..backends import load as load_audio
from ..utils import DATA_HOME
from ..utils import decompress
+from ..utils.download import download_and_decompress
from .dataset import feat_funcs
-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.utils.download import download_and_decompress
-from utils.utility import download
-from utils.utility import unpack
-
-logger = Log(__name__).getlog()

__all__ = ['VoxCeleb']
@@ -138,9 +133,9 @@ class VoxCeleb(Dataset):
        # Download audio files.
        # We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
        # so, we check the vox1/wav dir status
-        logger.info(f"wav base path: {self.wav_path}")
+        print(f"wav base path: {self.wav_path}")
        if not os.path.isdir(self.wav_path):
-            logger.info(f"start to download the voxceleb1 dataset")
+            print(f"start to download the voxceleb1 dataset")
            download_and_decompress(  # multi-zip parts concatenate to vox1_dev_wav.zip
                self.archieves_audio_dev,
                self.base_path,
@@ -152,7 +147,7 @@ class VoxCeleb(Dataset):
            # Download all parts and concatenate the files into one zip file.
            dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
-            logger.info(f'Concatenating all parts to: {dev_zipfile}')
+            print(f'Concatenating all parts to: {dev_zipfile}')
            os.system(
                f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
            )
@@ -162,6 +157,7 @@ class VoxCeleb(Dataset):
        # Download meta files.
        if not os.path.isdir(self.meta_path):
+            print("prepare the meta data")
            download_and_decompress(
                self.archieves_meta, self.meta_path, decompress=False)
@@ -171,7 +167,7 @@ class VoxCeleb(Dataset):
            self.prepare_data()

        data = []
-        logger.info(
+        print(
            f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
        )
        with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
@@ -266,8 +262,8 @@ class VoxCeleb(Dataset):
                     wav_files: List[str],
                     output_file: str,
                     split_chunks: bool=True):
-        logger.info(f'Generating csv: {output_file}')
+        print(f'Generating csv: {output_file}')
-        header = ["id", "duration", "wav", "start", "stop", "spk_id"]
+        header = ["ID", "duration", "wav", "start", "stop", "spk_id"]
        # Note: this may raise a C++ exception, but the program will execute fine,
        # so we can ignore the exception
        with Pool(cpu_count()) as p:
@@ -290,7 +286,7 @@ class VoxCeleb(Dataset):
    def prepare_data(self):
        # Audio of speakers in veri_test_file should not be included in training set.
-        logger.info("start to prepare the data csv file")
+        print("start to prepare the data csv file")
        enroll_files = set()
        test_files = set()
        # get the enroll and test audio file path
@@ -311,13 +307,13 @@ class VoxCeleb(Dataset):
        # get all the train and dev audio file paths
        audio_files = []
        speakers = set()
+        print("Getting file list...")
        for path in [self.wav_path, self.vox2_base_path]:
            # if vox2 directory is not set and vox2 is not a directory
            # we will not process this directory
            if not path or not os.path.exists(path):
-                logger.warning(
-                    f"{path} is an invalid path, please check again, "
-                    "and we will ignore the vox2 base path")
+                print(f"{path} is an invalid path, please check again, "
+                      "and we will ignore the vox2 base path")
                continue
            for file in glob.glob(
                    os.path.join(path, "**", "*.wav"), recursive=True):
@@ -327,7 +323,7 @@ class VoxCeleb(Dataset):
                    speakers.add(spk)
                    audio_files.append(file)

-        logger.info(
+        print(
            f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
        )
        # encode the train and dev speakers label to spk_id2label.txt
...
@@ -37,7 +37,9 @@ def decompress(file: str):
    download._decompress(file)

-def download_and_decompress(archives: List[Dict[str, str]], path: str):
+def download_and_decompress(archives: List[Dict[str, str]],
+                            path: str,
+                            decompress: bool=True):
    """
    Download archives and decompress to specific path.
    """
@@ -47,8 +49,8 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str):
    for archive in archives:
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
-        download.get_path_from_url(archive['url'], path, archive['md5'])
+        download.get_path_from_url(
+            archive['url'], path, archive['md5'], decompress=decompress)

def load_state_dict_from_url(url: str, path: str, md5: str=None):
...
@@ -14,12 +14,13 @@
import argparse
import os
+import time

import numpy as np
import paddle
from yacs.config import CfgNode

-from paddleaudio.paddleaudio.backends import load as load_audio
-from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
+from paddleaudio.backends import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
@@ -39,7 +40,7 @@ def extract_audio_embedding(args, config):
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage4: build the speaker verification train instance with backbone model
-    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)
+    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage 2: load the pre-trained model
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
@@ -60,7 +61,12 @@ def extract_audio_embedding(args, config):
    # feat type is numpy array, whose shape is [dim, time]
    # we need to convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
    # so the final shape is [1, dim, time]
-    feat = melspectrogram(x=waveform, **config.feature)
+    start_time = time.time()
+    feat = melspectrogram(x=waveform,
+                          sr=config.sample_rate,
+                          n_mels=config.n_mels,
+                          window_size=config.window_size,
+                          hop_length=config.hop_length)
    feat = paddle.to_tensor(feat).unsqueeze(0)

    # in inference period, the lengths is all one without padding
@@ -71,9 +77,13 @@ def extract_audio_embedding(args, config):
    # model backbone network forward the feats and get the embedding
    embedding = model.backbone(
        feat, lengths).squeeze().numpy()  # (1, emb_size, 1) -> (emb_size)
+    elapsed_time = time.time() - start_time
+    audio_length = waveform.shape[0] / sr

    # stage 5: do global norm with external mean and std
-    # todo
+    rtf = elapsed_time / audio_length
+    logger.info(f"{args.device} rtf={rtf}")

    return embedding
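RTF (real-time factor) is inference time divided by audio duration; values below 1.0 mean faster than real time. A worked example with assumed numbers, mirroring the computation above:

# hypothetical timings to illustrate the rtf computed above
elapsed_time = 0.5            # seconds spent on feature extraction + forward pass
audio_length = 80000 / 16000  # 80000 samples at 16 kHz = 5.0 s of audio
rtf = elapsed_time / audio_length
print(f"rtf={rtf}")           # 0.1, i.e. ten times faster than real time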
@@ -92,10 +102,6 @@ if __name__ == "__main__":
                        type=str,
                        default='',
                        help="Directory to load model checkpoint to continue training.")
-    parser.add_argument("--global-embedding-norm",
-                        type=str,
-                        default=None,
-                        help="Apply global normalization on speaker embeddings.")
    parser.add_argument("--audio-path",
                        default="./data/demo.wav",
                        type=str,
...
@@ -23,8 +23,8 @@ from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode

-from paddleaudio.paddleaudio.datasets import VoxCeleb
-from paddleaudio.paddleaudio.metric import compute_eer
+from paddleaudio.datasets import VoxCeleb
+from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
@@ -48,6 +48,9 @@ def main(args, config):
        backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage3: load the pre-trained model
+    # we get the last model from the epoch and save_interval
+    last_save_epoch = (config.epochs // config.save_interval) * config.save_interval
+    args.load_checkpoint = os.path.join(args.load_checkpoint, "epoch_" + str(last_save_epoch))
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
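The checkpoint directory is derived from the largest multiple of save_interval that does not exceed epochs; with this config's epochs: 10 and save_interval: 1 it resolves to epoch_10. A small arithmetic sketch with hypothetical values:

# largest saved epoch <= total epochs, for a fixed save interval
epochs, save_interval = 10, 3  # hypothetical values
last_save_epoch = (epochs // save_interval) * save_interval
print(f"epoch_{last_save_epoch}")  # epoch_9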
@@ -63,7 +66,9 @@ def main(args, config):
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
-        **config.feature)
+        n_mels=config.n_mels,
+        window_size=config.window_size,
+        hop_length=config.hop_length)
    enroll_sampler = BatchSampler(
        enroll_dataset, batch_size=config.batch_size,
        shuffle=True)  # Shuffle to make embedding normalization more robust.
@@ -73,13 +78,14 @@ def main(args, config):
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True,)

    test_dataset = VoxCeleb(
        subset='test',
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
-        **config.feature)
+        n_mels=config.n_mels,
+        window_size=config.window_size,
+        hop_length=config.hop_length)
    test_sampler = BatchSampler(
        test_dataset, batch_size=config.batch_size, shuffle=True)
@@ -89,19 +95,19 @@ def main(args, config):
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True,)

-    # stage6: we must set the model to eval mode
+    # stage5: we must set the model to eval mode
    model.eval()

-    # stage7: global embedding norm to improve the performance
-    print("global embedding norm: {}".format(args.global_embedding_norm))
-    if args.global_embedding_norm:
+    # stage6: global embedding norm to improve the performance
+    logger.info(f"global embedding norm: {config.global_embedding_norm}")
+    if config.global_embedding_norm:
        global_embedding_mean = None
        global_embedding_std = None
-        mean_norm_flag = args.embedding_mean_norm
-        std_norm_flag = args.embedding_std_norm
+        mean_norm_flag = config.embedding_mean_norm
+        std_norm_flag = config.embedding_std_norm
        batch_count = 0

-    # stage8: Compute embeddings of audios in enrol and test dataset from model.
+    # stage7: Compute embeddings of audios in enrol and test dataset from model.
    id2embedding = {}
    # Run multi times to make embedding normalization more stable.
    for i in range(2):
@@ -121,7 +127,7 @@ def main(args, config):
                # Global embedding normalization.
                # if we use the global embedding norm,
                # eer can be reduced by about a relative 10%
-                if args.global_embedding_norm:
+                if config.global_embedding_norm:
                    batch_count += 1
                    current_mean = embeddings.mean(
                        axis=0) if mean_norm_flag else 0
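The global statistics are accumulated across batches with a simple moving average; a NumPy sketch of that update rule (the real code operates on paddle tensors, and the helper name here is illustrative):

import numpy as np

def update_running_stat(global_stat, current_stat, batch_count):
    # moving average over batches: new = old + (cur - old) / n
    if global_stat is None:
        return current_stat
    return global_stat + (current_stat - global_stat) / batch_count

# hypothetical usage over two batches of embeddings
mean = None
for batch_count, emb in enumerate([np.ones((4, 192)), 3 * np.ones((4, 192))], start=1):
    mean = update_running_stat(mean, emb.mean(axis=0), batch_count)
print(mean[0])  # 2.0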
@@ -145,21 +151,22 @@ def main(args, config):
                # Update embedding dict.
                id2embedding.update(dict(zip(ids, embeddings)))

-    # stage 9: Compute cosine scores.
+    # stage 8: Compute cosine scores.
    labels = []
-    enrol_ids = []
+    enroll_ids = []
    test_ids = []
+    logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
    with open(VoxCeleb.veri_test_file, 'r') as f:
        for line in f.readlines():
-            label, enrol_id, test_id = line.strip().split(' ')
+            label, enroll_id, test_id = line.strip().split(' ')
            labels.append(int(label))
-            enrol_ids.append(enrol_id.split('.')[0].replace('/', '--'))
-            test_ids.append(test_id.split('.')[0].replace('/', '--'))
+            enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
+            test_ids.append(test_id.split('.')[0].replace('/', '-'))

    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
    enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
-        np.asarray([id2embedding[id] for id in ids], dtype='float32')),
-                                            [enrol_ids, test_ids
+        np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')),
+                                            [enroll_ids, test_ids
                                             ])  # (N, emb_size)
    scores = cos_sim_func(enrol_embeddings, test_embeddings)
    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
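A compact NumPy sketch of the pairwise cosine scoring above (the real code uses paddle.nn.CosineSimilarity and paddleaudio's compute_eer; the function here is illustrative):

import numpy as np

def cosine_scores(enroll, test):
    # row-wise cosine similarity between paired (N, emb_size) arrays
    enroll = enroll / np.linalg.norm(enroll, axis=1, keepdims=True)
    test = test / np.linalg.norm(test, axis=1, keepdims=True)
    return (enroll * test).sum(axis=1)

emb = np.random.randn(3, 192)
print(cosine_scores(emb, emb))  # identical pairs -> scores of 1.0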
@@ -187,17 +194,6 @@ if __name__ == "__main__":
                        type=str,
                        default='',
                        help="Directory to load model checkpoint to continue training.")
-    parser.add_argument("--global-embedding-norm",
-                        default=False,
-                        action="store_true",
-                        help="Apply global normalization on speaker embeddings.")
-    parser.add_argument("--embedding-mean-norm",
-                        default=True,
-                        help="Apply mean normalization on speaker embeddings.")
-    parser.add_argument("--embedding-std-norm",
-                        type=bool,
-                        default=False,
-                        help="Apply std normalization on speaker embeddings.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
...
@@ -21,8 +21,8 @@ from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

-from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
-from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
+from paddleaudio.compliance.librosa import melspectrogram
+from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
@@ -68,6 +68,8 @@ def main(args, config):
        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)

    # stage5: build the optimizer, we now only construct the AdamW optimizer
+    # 140000 is the single-gpu step count
+    # so, in multi-gpu mode, we reduce the step_size to 140000//nranks to enable CyclicLRScheduler
    lr_schedule = CyclicLRScheduler(
        base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
    optimizer = paddle.optimizer.AdamW(
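Why divide by nranks: under data parallelism each rank takes roughly 1/nranks of the optimizer steps per pass over the data, so shrinking the half-cycle length keeps the cyclic schedule aligned with the single-GPU setting. A quick arithmetic sketch with assumed values:

# hypothetical: 140000 single-gpu steps per half cycle, spread over 4 ranks
single_gpu_step_size = 140000
nranks = 4
step_size = single_gpu_step_size // nranks
print(step_size)  # 35000 steps per rank per half cycle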
@@ -138,6 +140,10 @@ def main(args, config):
            waveforms, labels = batch['waveforms'], batch['labels']

            # stage 9-2: audio sample augment method, which is done on the audio sample point
+            # the original waveform and the augmented waveforms are concatenated in one batch
+            # e.g. with five augment methods in the augment pipeline,
+            # the final number of samples is batch_size * (five + one):
+            # five augmented waveform batches plus the one original waveform batch
            if len(augment_pipeline) != 0:
                waveforms = waveform_augment(waveforms, augment_pipeline)
                labels = paddle.concat(
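Shape bookkeeping for the comment above, with assumed sizes: a batch of 16 and five augmentations yield 16 * (5 + 1) = 96 waveforms, and the labels are tiled to match:

# hypothetical sizes: batch of 16 waveforms, five augment methods
batch_size, n_aug = 16, 5
total = batch_size * (n_aug + 1)  # one original batch + n_aug augmented batches
print(total)  # 96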
@@ -146,7 +152,11 @@ def main(args, config):
            # stage 9-3: extract the audio feats, such as fbank, mfcc, spectrogram
            feats = []
            for waveform in waveforms.numpy():
-                feat = melspectrogram(x=waveform, **config.feature)
+                feat = melspectrogram(x=waveform,
+                                      sr=config.sample_rate,
+                                      n_mels=config.n_mels,
+                                      window_size=config.window_size,
+                                      hop_length=config.hop_length)
                feats.append(feat)
            feats = paddle.to_tensor(np.asarray(feats))
@@ -205,7 +215,7 @@ def main(args, config):
            # stage 9-12: construct the valid dataset dataloader
            dev_sampler = BatchSampler(
                dev_dataset,
-                batch_size=config.batch_size // 4,
+                batch_size=config.batch_size,
                shuffle=False,
                drop_last=False)
            dev_loader = DataLoader(
@@ -228,8 +238,11 @@ def main(args, config):
                feats = []
                for waveform in waveforms.numpy():
-                    # feat = melspectrogram(x=waveform, **cpu_feat_conf)
-                    feat = melspectrogram(x=waveform, **config.feature)
+                    feat = melspectrogram(x=waveform,
+                                          sr=config.sample_rate,
+                                          n_mels=config.n_mels,
+                                          window_size=config.window_size,
+                                          hop_length=config.hop_length)
                    feats.append(feat)
                feats = paddle.to_tensor(np.asarray(feats))
...
@@ -22,8 +22,8 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F

-from paddleaudio.paddleaudio import load as load_audio
-from paddleaudio.paddleaudio.datasets.rirs_noises import OpenRIRNoise
+from paddleaudio import load as load_audio
+from paddleaudio.datasets.rirs_noises import OpenRIRNoise
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.signal_processing import compute_amplitude
from paddlespeech.vector.io.signal_processing import convolve1d
@@ -879,14 +879,18 @@ def waveform_augment(waveforms: paddle.Tensor,
    """process the augment pipeline and return all the waveforms

    Args:
-        waveforms (paddle.Tensor): _description_
-        augment_pipeline (List[paddle.nn.Layer]): _description_
+        waveforms (paddle.Tensor): original batch waveform
+        augment_pipeline (List[paddle.nn.Layer]): augment pipeline to process the waveforms

    Returns:
-        paddle.Tensor: _description_
+        paddle.Tensor: all the audio waveforms, the original waveform plus the augmented ones
    """
+    # stage 0: store the original waveforms
    waveforms_aug_list = [waveforms]

+    # augment the original batch waveform
    for aug in augment_pipeline:
+        # stage 1: augment the data
        waveforms_aug = aug(waveforms)  # (N, L)
        if waveforms_aug.shape[1] >= waveforms.shape[1]:
            # Trunc
@@ -897,6 +901,8 @@ def waveform_augment(waveforms: paddle.Tensor,
            waveforms_aug = F.pad(
                waveforms_aug.unsqueeze(-1), [0, lengths_to_pad],
                data_format='NLC').squeeze(-1)
+        # stage 2: append the augmented waveform to the list
        waveforms_aug_list.append(waveforms_aug)

+    # get all the waveforms
    return paddle.concat(waveforms_aug_list, axis=0)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from typing import List

from paddle.framework import load as load_state_dict
from paddle.utils import download

__all__ = [
    'decompress',
    'download_and_decompress',
    'load_state_dict_from_url',
]


def decompress(file: str, path: str=None):
    """
    Extracts all files from a compressed file to specific path.
    """
    assert os.path.isfile(file), "File: {} not exists.".format(file)

    if path is None:
        print("decompress the data: {}".format(file))
        download._decompress(file)
    else:
        print("decompress the data: {} to {}".format(file, path))
        if not os.path.isdir(path):
            os.makedirs(path)

        tmp_file = os.path.join(path, os.path.basename(file))
        os.rename(file, tmp_file)
        download._decompress(tmp_file)
        os.rename(tmp_file, file)


def download_and_decompress(archives: List[Dict[str, str]],
                            path: str,
                            decompress: bool=True):
    """
    Download archives and decompress to specific path.
    """
    if not os.path.isdir(path):
        os.makedirs(path)

    for archive in archives:
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
        download.get_path_from_url(
            archive['url'], path, archive['md5'], decompress=decompress)


def load_state_dict_from_url(url: str, path: str, md5: str=None):
    """
    Download and load a state dict from url
    """
    if not os.path.isdir(path):
        os.makedirs(path)

    download.get_path_from_url(url, path, md5)
    return load_state_dict(os.path.join(path, os.path.basename(url)))
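A hedged usage sketch of the new decompress flag (the url and md5 below are placeholders, not real archive entries): plain-text meta files such as the VoxCeleb trial list are fetched with decompress=False, while audio archives keep the default of True.

# hypothetical archive entry; url and md5 are placeholders
meta_archives = [{"url": "https://example.com/veri_test2.txt", "md5": "0" * 32}]
download_and_decompress(meta_archives, path="./data/vox1/meta", decompress=False)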