提交 8ed5c287 编写于 作者: X xiongxinlei

add vox2 data into VoxCeleb class

上级 584a2c0e
...@@ -23,39 +23,6 @@ VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech, ...@@ -23,39 +23,6 @@ VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech,
ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s
``` ```
``` shell
# copy this to root directory of data and
# chmod a+x convert.sh
# ./convert.sh
# https://unix.stackexchange.com/questions/103920/parallelize-a-bash-for-loop
open_sem(){
mkfifo pipe-$$
exec 3<>pipe-$$
rm pipe-$$
local i=$1
for((;i>0;i--)); do
printf %s 000 >&3
done
}
run_with_lock(){
local x
read -u 3 -n 3 x && ((0==x)) || exit $x
(
( "$@"; )
printf '%.3d' $? >&3
)&
}
N=32 # number of vCPU
open_sem $N
for f in $(find . -name "*.m4a"); do
run_with_lock ffmpeg -loglevel panic -i "$f" -ar 16000 "${f%.*}.wav"
done
```
You can do the conversion using ffmpeg https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and should be only once. You can do the conversion using ffmpeg https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and should be only once.
3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g, `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`) 3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g, `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`)
4.
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import os import os
import numpy as np import numpy as np
import paddle import paddle
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1 from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def main(args): def main(args):
# stage0: set the cpu device, all data prepare process will be done in cpu mode # stage0: set the cpu device, all data prepare process will be done in cpu mode
paddle.set_device("cpu") paddle.set_device("cpu")
# set the random seed, it is a must for multiprocess training # set the random seed, it is a must for multiprocess training
...@@ -19,14 +34,18 @@ def main(args): ...@@ -19,14 +34,18 @@ def main(args):
# stage 1: generate the voxceleb csv file # stage 1: generate the voxceleb csv file
# Note: this may occurs c++ execption, but the program will execute fine # Note: this may occurs c++ execption, but the program will execute fine
# so we can ignore the execption # so we ignore the execption
train_dataset = VoxCeleb1('train', target_dir=args.data_dir) # we explicitly pass the vox2 base path to data prepare and generate the audio info
dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir) train_dataset = VoxCeleb(
'train', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
dev_dataset = VoxCeleb(
'dev', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
# stage 2: generate the augment noise csv file # stage 2: generate the augment noise csv file
if args.augment: if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
if __name__ == "__main__": if __name__ == "__main__":
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
...@@ -38,10 +57,14 @@ if __name__ == "__main__": ...@@ -38,10 +57,14 @@ if __name__ == "__main__":
default="./data/", default="./data/",
type=str, type=str,
help="data directory") help="data directory")
parser.add_argument("--vox2-base-path",
default=None,
type=str,
help="vox2 base path, where is store the wav audio")
parser.add_argument("--augment", parser.add_argument("--augment",
action="store_true", action="store_true",
default=False, default=False,
help="Apply audio augments.") help="Apply audio augments.")
args = parser.parse_args() args = parser.parse_args()
# yapf: enable # yapf: enable
main(args) main(args)
\ No newline at end of file
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export MAIN_ROOT=`realpath ${PWD}/../../../` export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
...@@ -10,5 +24,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} ...@@ -10,5 +24,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=ecapa-tdnn MODEL=ecapa_tdnn
export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL} export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}
\ No newline at end of file
#!/bin/bash #!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh . ./path.sh
set -e set -e
...@@ -11,19 +24,30 @@ set -e ...@@ -11,19 +24,30 @@ set -e
# stage 3: extract the training embeding to train the LDA and PLDA # stage 3: extract the training embeding to train the LDA and PLDA
###################################################################### ######################################################################
# you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset # we can set the variable PPAUDIO_HOME to specifiy the root directory of the downloaded vox1 and vox2 dataset
# default the dataset is the ~/.paddleaudio/ # default the dataset will be stored in the ~/.paddleaudio/
# the vox2 dataset is stored in m4a format, we need to convert the audio from m4a to wav yourself
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME= # export PPAUDIO_HOME=
stage=0 stage=0
dir=data.bak/ # data directory # data directory
exp_dir=exp/ecapa-tdnn/ # experiment directory # if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
dir=data/
exp_dir=exp/ecapa-tdnn/ # experiment directory
# vox2 wav path, we must convert the m4a format to wav format
# and store them in the ${PPAUDIO_HOME}/datasets/vox2/wav/ directory
vox2_base_path=${PPAUDIO_HOME}/datasets/vox2/wav/
mkdir -p ${dir} mkdir -p ${dir}
mkdir -p ${exp_dir} mkdir -p ${exp_dir}
if [ $stage -le 0 ]; then if [ $stage -le 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
python3 local/data_prepare.py --data-dir ${dir} --augment python3 local/data_prepare.py \
--data-dir ${dir} --augment --vox2-base-path ${vox2_base_path}
fi fi
if [ $stage -le 1 ]; then if [ $stage -le 1 ]; then
......
...@@ -15,5 +15,5 @@ from .esc50 import ESC50 ...@@ -15,5 +15,5 @@ from .esc50 import ESC50
from .gtzan import GTZAN from .gtzan import GTZAN
from .tess import TESS from .tess import TESS
from .urban_sound import UrbanSound8K from .urban_sound import UrbanSound8K
from .voxceleb import VoxCeleb1 from .voxceleb import VoxCeleb
from .rirs_noises import OpenRIRNoise from .rirs_noises import OpenRIRNoise
...@@ -25,10 +25,10 @@ from paddle.io import Dataset ...@@ -25,10 +25,10 @@ from paddle.io import Dataset
from pathos.multiprocessing import Pool from pathos.multiprocessing import Pool
from tqdm import tqdm from tqdm import tqdm
from .dataset import feat_funcs
from ..backends import load as load_audio from ..backends import load as load_audio
from ..utils import DATA_HOME from ..utils import DATA_HOME
from ..utils import decompress from ..utils import decompress
from .dataset import feat_funcs
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.download import download_and_decompress from paddlespeech.vector.utils.download import download_and_decompress
from utils.utility import download from utils.utility import download
...@@ -36,10 +36,10 @@ from utils.utility import unpack ...@@ -36,10 +36,10 @@ from utils.utility import unpack
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['VoxCeleb1'] __all__ = ['VoxCeleb']
class VoxCeleb1(Dataset): class VoxCeleb(Dataset):
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/' source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
archieves_audio_dev = [ archieves_audio_dev = [
{ {
...@@ -94,8 +94,18 @@ class VoxCeleb1(Dataset): ...@@ -94,8 +94,18 @@ class VoxCeleb1(Dataset):
split_ratio: float=0.9, # train split ratio split_ratio: float=0.9, # train split ratio
seed: int=0, seed: int=0,
target_dir: str=None, target_dir: str=None,
vox2_base_path=None,
**kwargs): **kwargs):
"""VoxCeleb data prepare and get the specific dataset audio info
Args:
subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
random_chunk (bool, optional): random select a duration from audio. Defaults to True.
chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
"""
assert subset in self.subsets, \ assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset) 'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
...@@ -106,19 +116,20 @@ class VoxCeleb1(Dataset): ...@@ -106,19 +116,20 @@ class VoxCeleb1(Dataset):
self.random_chunk = random_chunk self.random_chunk = random_chunk
self.chunk_duration = chunk_duration self.chunk_duration = chunk_duration
self.split_ratio = split_ratio self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else VoxCeleb1.base_path self.target_dir = target_dir if target_dir else VoxCeleb.base_path
self.vox2_base_path = vox2_base_path
# if we set the target dir, we will change the vox data info data from base path to target dir # if we set the target dir, we will change the vox data info data from base path to target dir
VoxCeleb1.csv_path = os.path.join( VoxCeleb.csv_path = os.path.join(
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb1.csv_path target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
VoxCeleb1.meta_path = os.path.join( VoxCeleb.meta_path = os.path.join(
target_dir, "voxceleb", target_dir, "voxceleb",
'meta') if target_dir else VoxCeleb1.meta_path 'meta') if target_dir else VoxCeleb.meta_path
VoxCeleb1.veri_test_file = os.path.join(VoxCeleb1.meta_path, VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
'veri_test2.txt') 'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test. # self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data() self._data = self._get_data()
super(VoxCeleb1, self).__init__() super(VoxCeleb, self).__init__()
# Set up a seed to reproduce training or predicting result. # Set up a seed to reproduce training or predicting result.
# random.seed(seed) # random.seed(seed)
...@@ -300,7 +311,14 @@ class VoxCeleb1(Dataset): ...@@ -300,7 +311,14 @@ class VoxCeleb1(Dataset):
# get all the train and dev audios file path # get all the train and dev audios file path
audio_files = [] audio_files = []
speakers = set() speakers = set()
for path in [self.wav_path]: for path in [self.wav_path, self.vox2_base_path]:
# if vox2 directory is not set and vox2 is not a directory
# we will not process this directory
if not path or not os.path.exists(path):
logger.warning(
f"{path} is an invalid path, please check again, "
"and we will ignore the vox2 base path")
continue
for file in glob.glob( for file in glob.glob(
os.path.join(path, "**", "*.wav"), recursive=True): os.path.join(path, "**", "*.wav"), recursive=True):
spk = file.split('/wav/')[1].split('/')[0] spk = file.split('/wav/')[1].split('/')[0]
......
...@@ -28,6 +28,7 @@ from paddlespeech.vector.training.seeding import seed_everything ...@@ -28,6 +28,7 @@ from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def extract_audio_embedding(args, config): def extract_audio_embedding(args, config):
# stage 0: set the training device, cpu or gpu # stage 0: set the training device, cpu or gpu
paddle.set_device(args.device) paddle.set_device(args.device)
...@@ -83,7 +84,7 @@ if __name__ == "__main__": ...@@ -83,7 +84,7 @@ if __name__ == "__main__":
choices=['cpu', 'gpu'], choices=['cpu', 'gpu'],
default="gpu", default="gpu",
help="Select which device to train model, defaults to gpu.") help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config", parser.add_argument("--config",
default=None, default=None,
type=str, type=str,
help="configuration file") help="configuration file")
......
...@@ -17,15 +17,15 @@ import os ...@@ -17,15 +17,15 @@ import os
import numpy as np import numpy as np
import paddle import paddle
from yacs.config import CfgNode
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.io import BatchSampler from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from tqdm import tqdm from tqdm import tqdm
from yacs.config import CfgNode
from paddleaudio.paddleaudio.datasets import VoxCeleb1 from paddleaudio.paddleaudio.datasets import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddleaudio.paddleaudio.metric import compute_eer from paddleaudio.paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification
...@@ -33,6 +33,7 @@ from paddlespeech.vector.training.seeding import seed_everything ...@@ -33,6 +33,7 @@ from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def main(args, config): def main(args, config):
# stage0: set the training device, cpu or gpu # stage0: set the training device, cpu or gpu
paddle.set_device(args.device) paddle.set_device(args.device)
...@@ -44,7 +45,7 @@ def main(args, config): ...@@ -44,7 +45,7 @@ def main(args, config):
# stage2: build the speaker verification eval instance with backbone model # stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification( model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers) backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
# stage3: load the pre-trained model # stage3: load the pre-trained model
args.load_checkpoint = os.path.abspath( args.load_checkpoint = os.path.abspath(
...@@ -57,7 +58,7 @@ def main(args, config): ...@@ -57,7 +58,7 @@ def main(args, config):
logger.info(f'Checkpoint loaded from {args.load_checkpoint}') logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# stage4: construct the enroll and test dataloader # stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb1( enroll_dataset = VoxCeleb(
subset='enroll', subset='enroll',
target_dir=args.data_dir, target_dir=args.data_dir,
feat_type='melspectrogram', feat_type='melspectrogram',
...@@ -73,7 +74,7 @@ def main(args, config): ...@@ -73,7 +74,7 @@ def main(args, config):
num_workers=config.num_workers, num_workers=config.num_workers,
return_list=True,) return_list=True,)
test_dataset = VoxCeleb1( test_dataset = VoxCeleb(
subset='test', subset='test',
target_dir=args.data_dir, target_dir=args.data_dir,
feat_type='melspectrogram', feat_type='melspectrogram',
...@@ -145,7 +146,7 @@ def main(args, config): ...@@ -145,7 +146,7 @@ def main(args, config):
labels = [] labels = []
enrol_ids = [] enrol_ids = []
test_ids = [] test_ids = []
with open(VoxCeleb1.veri_test_file, 'r') as f: with open(VoxCeleb.veri_test_file, 'r') as f:
for line in f.readlines(): for line in f.readlines():
label, enrol_id, test_id = line.strip().split(' ') label, enrol_id, test_id = line.strip().split(' ')
labels.append(int(label)) labels.append(int(label))
...@@ -171,7 +172,7 @@ if __name__ == "__main__": ...@@ -171,7 +172,7 @@ if __name__ == "__main__":
choices=['cpu', 'gpu'], choices=['cpu', 'gpu'],
default="gpu", default="gpu",
help="Select which device to train model, defaults to gpu.") help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config", parser.add_argument("--config",
default=None, default=None,
type=str, type=str,
help="configuration file") help="configuration file")
......
...@@ -20,8 +20,9 @@ from paddle.io import BatchSampler ...@@ -20,8 +20,9 @@ from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode from yacs.config import CfgNode
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1 from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment from paddlespeech.vector.io.augment import waveform_augment
...@@ -30,13 +31,14 @@ from paddlespeech.vector.io.batch import waveform_collate_fn ...@@ -30,13 +31,14 @@ from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer from paddlespeech.vector.utils.time import Timer
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def main(args, config): def main(args, config):
# stage0: set the training device, cpu or gpu # stage0: set the training device, cpu or gpu
paddle.set_device(args.device) paddle.set_device(args.device)
...@@ -50,8 +52,8 @@ def main(args, config): ...@@ -50,8 +52,8 @@ def main(args, config):
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code # note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset = VoxCeleb1('train', target_dir=args.data_dir) train_dataset = VoxCeleb('train', target_dir=args.data_dir)
dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir) dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
if args.augment: if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
...@@ -63,7 +65,7 @@ def main(args, config): ...@@ -63,7 +65,7 @@ def main(args, config):
# stage4: build the speaker verification train instance with backbone model # stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification( model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers) backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
# stage5: build the optimizer, we now only construct the AdamW optimizer # stage5: build the optimizer, we now only construct the AdamW optimizer
lr_schedule = CyclicLRScheduler( lr_schedule = CyclicLRScheduler(
...@@ -263,7 +265,7 @@ if __name__ == "__main__": ...@@ -263,7 +265,7 @@ if __name__ == "__main__":
choices=['cpu', 'gpu'], choices=['cpu', 'gpu'],
default="cpu", default="cpu",
help="Select which device to train model, defaults to gpu.") help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config", parser.add_argument("--config",
default=None, default=None,
type=str, type=str,
help="configuration file") help="configuration file")
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# this is modified from https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import math import math
import os import os
from typing import List from typing import List
......
...@@ -19,16 +19,6 @@ import paddle.nn.functional as F ...@@ -19,16 +19,6 @@ import paddle.nn.functional as F
def length_to_mask(length, max_len=None, dtype=None): def length_to_mask(length, max_len=None, dtype=None):
"""_summary_
Args:
length (_type_): _description_
max_len (_type_, optional): _description_. Defaults to None.
dtype (_type_, optional): _description_. Defaults to None.
Returns:
_type_: _description_
"""
assert len(length.shape) == 1 assert len(length.shape) == 1
if max_len is None: if max_len is None:
...@@ -60,15 +50,15 @@ class Conv1d(nn.Layer): ...@@ -60,15 +50,15 @@ class Conv1d(nn.Layer):
"""_summary_ """_summary_
Args: Args:
in_channels (_type_): _description_ in_channels (int): intput channel or input data dimensions
out_channels (_type_): _description_ out_channels (int): output channel or output data dimensions
kernel_size (_type_): _description_ kernel_size (int): kernel size of 1-d convolution
stride (int, optional): _description_. Defaults to 1. stride (int, optional): strid in 1-d convolution . Defaults to 1.
padding (str, optional): _description_. Defaults to "same". padding (str, optional): padding value. Defaults to "same".
dilation (int, optional): _description_. Defaults to 1. dilation (int, optional): dilation in 1-d convolution. Defaults to 1.
groups (int, optional): _description_. Defaults to 1. groups (int, optional): groups in 1-d convolution. Defaults to 1.
bias (bool, optional): _description_. Defaults to True. bias (bool, optional): bias in 1-d convolution . Defaults to True.
padding_mode (str, optional): _description_. Defaults to "reflect". padding_mode (str, optional): padding mode. Defaults to "reflect".
""" """
super().__init__() super().__init__()
...@@ -89,17 +79,6 @@ class Conv1d(nn.Layer): ...@@ -89,17 +79,6 @@ class Conv1d(nn.Layer):
bias_attr=bias, ) bias_attr=bias, )
def forward(self, x): def forward(self, x):
"""_summary_
Args:
x (_type_): _description_
Raises:
ValueError: _description_
Returns:
_type_: _description_
"""
if self.padding == "same": if self.padding == "same":
x = self._manage_padding(x, self.kernel_size, self.dilation, x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride) self.stride)
...@@ -109,17 +88,6 @@ class Conv1d(nn.Layer): ...@@ -109,17 +88,6 @@ class Conv1d(nn.Layer):
return self.conv(x) return self.conv(x)
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
"""_summary_
Args:
x (_type_): _description_
kernel_size (int): _description_
dilation (int): _description_
stride (int): _description_
Returns:
_type_: _description_
"""
L_in = x.shape[-1] # Detecting input shape L_in = x.shape[-1] # Detecting input shape
padding = self._get_padding_elem(L_in, stride, kernel_size, padding = self._get_padding_elem(L_in, stride, kernel_size,
dilation) # Time padding dilation) # Time padding
...@@ -133,17 +101,6 @@ class Conv1d(nn.Layer): ...@@ -133,17 +101,6 @@ class Conv1d(nn.Layer):
stride: int, stride: int,
kernel_size: int, kernel_size: int,
dilation: int): dilation: int):
"""_summary_
Args:
L_in (int): _description_
stride (int): _description_
kernel_size (int): _description_
dilation (int): _description_
Returns:
_type_: _description_
"""
if stride > 1: if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation L_out = stride * (n_steps - 1) + kernel_size * dilation
...@@ -220,8 +177,8 @@ class Res2NetBlock(nn.Layer): ...@@ -220,8 +177,8 @@ class Res2NetBlock(nn.Layer):
Args: Args:
in_channels (int): input channels or input dimensions in_channels (int): input channels or input dimensions
out_channels (int): output channels or output dimensions out_channels (int): output channels or output dimensions
scale (int, optional): _description_. Defaults to 8. scale (int, optional): scale in res2net bolck. Defaults to 8.
dilation (int, optional): _description_. Defaults to 1. dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1.
""" """
super().__init__() super().__init__()
assert in_channels % scale == 0 assert in_channels % scale == 0
...@@ -358,15 +315,16 @@ class SERes2NetBlock(nn.Layer): ...@@ -358,15 +315,16 @@ class SERes2NetBlock(nn.Layer):
dilation=1, dilation=1,
activation=nn.ReLU, ): activation=nn.ReLU, ):
"""Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model """Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model
The paper is refered "Squeeze-and-Excitation Networks"
whose url is: https://arxiv.org/pdf/1709.01507.pdf
Args: Args:
in_channels (int): input channels or input data dimensions in_channels (int): input channels or input data dimensions
out_channels (_type_): _description_ out_channels (int): output channels or output data dimensions
res2net_scale (int, optional): _description_. Defaults to 8. res2net_scale (int, optional): scale in the res2net block. Defaults to 8.
se_channels (int, optional): _description_. Defaults to 128. se_channels (int, optional): embedding dimensions of res2net block. Defaults to 128.
kernel_size (int, optional): _description_. Defaults to 1. kernel_size (int, optional): kernel size of 1-d convolution in TDNN block. Defaults to 1.
dilation (int, optional): _description_. Defaults to 1. dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1.
activation (_type_, optional): _description_. Defaults to nn.ReLU. activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU.
""" """
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
...@@ -419,7 +377,21 @@ class EcapaTdnn(nn.Layer): ...@@ -419,7 +377,21 @@ class EcapaTdnn(nn.Layer):
res2net_scale=8, res2net_scale=8,
se_channels=128, se_channels=128,
global_context=True, ): global_context=True, ):
"""Implementation of ECAPA-TDNN backbone model network
The paper is refered as "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification"
whose url is: https://arxiv.org/abs/2005.07143
Args:
input_size (_type_): input fature dimension
lin_neurons (int, optional): speaker embedding size. Defaults to 192.
activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU.
channels (list, optional): inter embedding dimension. Defaults to [512, 512, 512, 512, 1536].
kernel_sizes (list, optional): kernel size of 1-d convolution in TDNN block . Defaults to [5, 3, 3, 3, 1].
dilations (list, optional): dilations of 1-d convolution in TDNN block. Defaults to [1, 2, 3, 4, 1].
attention_channels (int, optional): attention dimensions. Defaults to 128.
res2net_scale (int, optional): scale value in res2net. Defaults to 8.
se_channels (int, optional): dimensions of squeeze-excitation block. Defaults to 128.
global_context (bool, optional): global context flag. Defaults to True.
"""
super().__init__() super().__init__()
assert len(channels) == len(kernel_sizes) assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations) assert len(channels) == len(dilations)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册