diff --git a/examples/voxceleb/README.md b/examples/voxceleb/README.md
index 59fb491ce903bd037bc5a86ce2f3b89100db25af..fc847cd8a48a6faa81b63ab1433c9b34702e6a94 100644
--- a/examples/voxceleb/README.md
+++ b/examples/voxceleb/README.md
@@ -23,39 +23,6 @@ VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech,
 ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s
 ```
 
-``` shell
-# copy this to root directory of data and
-# chmod a+x convert.sh
-# ./convert.sh
-# https://unix.stackexchange.com/questions/103920/parallelize-a-bash-for-loop
-
-open_sem(){
-    mkfifo pipe-$$
-    exec 3<>pipe-$$
-    rm pipe-$$
-    local i=$1
-    for((;i>0;i--)); do
-        printf %s 000 >&3
-    done
-}
-run_with_lock(){
-    local x
-    read -u 3 -n 3 x && ((0==x)) || exit $x
-    (
-        ( "$@"; )
-        printf '%.3d' $? >&3
-    )&
-}
-
-N=32 # number of vCPU
-open_sem $N
-for f in $(find . -name "*.m4a"); do
-    run_with_lock ffmpeg -loglevel panic -i "$f" -ar 16000 "${f%.*}.wav"
-done
-```
-
 You can do the conversion using ffmpeg (https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and only needs to be done once.
 
 3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g., `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`)
-
-4. 
\ No newline at end of file
diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py
index 1a0a639275bfcea22f6b91fb20c0ff6a46529879..b906b5da4895559a602f94c1ac7f5e683daeea15 100644
--- a/examples/voxceleb/sv0/local/data_prepare.py
+++ b/examples/voxceleb/sv0/local/data_prepare.py
@@ -1,17 +1,32 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import os
 
 import numpy as np
 import paddle
 
-from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
+from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.augment import build_augment_pipeline
 from paddlespeech.vector.training.seeding import seed_everything
 
 logger = Log(__name__).getlog()
 
+
 def main(args):
+    # stage 0: set the cpu device; all data preparation will be done in cpu mode
     paddle.set_device("cpu")
 
     # set the random seed, it is a must for multiprocess training
@@ -19,14 +34,18 @@ def main(args):
 
     # stage 1: generate the voxceleb csv file
     # Note: this may occurs c++ execption, but the program will execute fine
-    # so we can ignore the execption
-    train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
-    dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
+    # so we ignore the exception
+    # we explicitly pass the vox2 base path to the data prepare step to generate the audio info
+    train_dataset = VoxCeleb(
+        'train', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
+    dev_dataset = VoxCeleb(
+        'dev', target_dir=args.data_dir, vox2_base_path=args.vox2_base_path)
 
     # stage 2: generate the augment noise csv file
     if args.augment:
         augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
 
+
 if __name__ == "__main__":
     # yapf: disable
     parser = argparse.ArgumentParser(__doc__)
@@ -38,10 +57,14 @@ if __name__ == "__main__":
                         default="./data/",
                         type=str,
                         help="data directory")
+    parser.add_argument("--vox2-base-path",
+                        default=None,
+                        type=str,
+                        help="vox2 base path, where the wav audio is stored")
     parser.add_argument("--augment",
                         action="store_true",
                         default=False,
                         help="Apply audio augments.")
     args = parser.parse_args()
     # yapf: enable
-    main(args)
\ No newline at end of file
+    main(args)
diff --git a/examples/voxceleb/sv0/path.sh b/examples/voxceleb/sv0/path.sh
index 6d19f99482fc58865e4ebca20e4ff2b31f508487..2be098e04ec2dc8e2b88111d1cf713f7b7978677 100755
--- a/examples/voxceleb/sv0/path.sh
+++ b/examples/voxceleb/sv0/path.sh
@@ -1,3 +1,17 @@
+#!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 export MAIN_ROOT=`realpath ${PWD}/../../../`
 
 export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
@@ -10,5 +24,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
 
-MODEL=ecapa-tdnn
+MODEL=ecapa_tdnn
 export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}
\ No newline at end of file
diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh
index 2c0e55a659f864167b67595c8b7aec22a89ea495..769332eb793879e04fa80d607921ea05ea8ee174 100755
--- a/examples/voxceleb/sv0/run.sh
+++ b/examples/voxceleb/sv0/run.sh
@@ -1,4 +1,17 @@
 #!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 . ./path.sh
 set -e
@@ -11,19 +24,30 @@ set -e
 # stage 3: extract the training embeding to train the LDA and PLDA
 ######################################################################
 
-# you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
-# default the dataset is the ~/.paddleaudio/
+# we can set the variable PPAUDIO_HOME to specify the root directory of the downloaded vox1 and vox2 datasets
+# by default the datasets are stored in ~/.paddleaudio/
+# the vox2 dataset is stored in m4a format, so you need to convert the audio from m4a to wav yourself
+# and put all of the wav files in ${PPAUDIO_HOME}/datasets/vox2
+# the wav files are then found under ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
 # export PPAUDIO_HOME=
 stage=0
 
-dir=data.bak/ # data directory
-exp_dir=exp/ecapa-tdnn/ # experiment directory
+# data directory
+# if we set the variable ${dir}, the wav info will be stored in this directory
+# otherwise, the wav info will be stored in the vox1 and vox2 directories respectively
+dir=data/
+exp_dir=exp/ecapa-tdnn/ # experiment directory
+
+# vox2 wav path; we must convert the audio from m4a to wav format
+# and store the wav files in the ${PPAUDIO_HOME}/datasets/vox2/wav/ directory
+vox2_base_path=${PPAUDIO_HOME}/datasets/vox2/wav/
 mkdir -p ${dir}
 mkdir -p ${exp_dir}
 
 if [ $stage -le 0 ]; then
     # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
-    python3 local/data_prepare.py --data-dir ${dir} --augment
+    python3 local/data_prepare.py \
+        --data-dir ${dir} --augment --vox2-base-path ${vox2_base_path}
 fi
 
 if [ $stage -le 1 ]; then
diff --git a/paddleaudio/paddleaudio/datasets/__init__.py b/paddleaudio/paddleaudio/datasets/__init__.py
index cbf9b3aedbbc11488fd87ed4ba7a46044a84ef25..6f44e97788e89a67f9d17077e75d85f0afa80eaa 100644
--- a/paddleaudio/paddleaudio/datasets/__init__.py
+++ b/paddleaudio/paddleaudio/datasets/__init__.py
@@ -15,5 +15,5 @@ from .esc50 import ESC50
 from .gtzan import GTZAN
 from .tess import TESS
 from .urban_sound import UrbanSound8K
-from .voxceleb import VoxCeleb1
+from .voxceleb import VoxCeleb
 from .rirs_noises import OpenRIRNoise
diff --git a/paddleaudio/paddleaudio/datasets/voxceleb.py b/paddleaudio/paddleaudio/datasets/voxceleb.py
index 4989accb7abe36bed6de59641858b3bc492bb3a4..f8d634f2491938c8559289292126eef88ec8d347 100644
--- a/paddleaudio/paddleaudio/datasets/voxceleb.py
+++ b/paddleaudio/paddleaudio/datasets/voxceleb.py
@@ -25,10 +25,10 @@ from paddle.io import Dataset
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
 
-from .dataset import feat_funcs
 from ..backends import load as load_audio
 from ..utils import DATA_HOME
 from ..utils import decompress
+from .dataset import feat_funcs
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.utils.download import download_and_decompress
 from utils.utility import download
@@ -36,10 +36,10 @@ from utils.utility import unpack
 
 logger = Log(__name__).getlog()
 
-__all__ = ['VoxCeleb1']
+__all__ = ['VoxCeleb']
 
 
-class VoxCeleb1(Dataset):
+class VoxCeleb(Dataset):
     source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
     archieves_audio_dev = [
         {
@@ -94,8 +94,18 @@ class VoxCeleb1(Dataset):
                  split_ratio: float=0.9,  # train split ratio
                  seed: int=0,
                  target_dir: str=None,
+                 vox2_base_path: str=None,
                  **kwargs):
-
+        """Prepare the VoxCeleb data and get the audio info for the specified subset
+
+        Args:
+            subset (str, optional): dataset subset, such as train, dev, enroll or test. Defaults to 'train'.
+            feat_type (str, optional): feat type, such as raw, melspectrogram (fbank) or mfcc. Defaults to 'raw'.
+            random_chunk (bool, optional): randomly select a chunk from the audio. Defaults to True.
+            chunk_duration (float, optional): chunk duration if the random_chunk flag is set. Defaults to 3.0.
+            target_dir (str, optional): data dir; audio info will be stored in this directory. Defaults to None.
+            vox2_base_path (str, optional): vox2 directory; vox2 data must be converted from m4a to wav. Defaults to None.
+        """
         assert subset in self.subsets, \
             'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
@@ -106,19 +116,20 @@ class VoxCeleb1(Dataset):
         self.random_chunk = random_chunk
         self.chunk_duration = chunk_duration
         self.split_ratio = split_ratio
-        self.target_dir = target_dir if target_dir else VoxCeleb1.base_path
+        self.target_dir = target_dir if target_dir else VoxCeleb.base_path
+        self.vox2_base_path = vox2_base_path
 
         # if we set the target dir, we will change the vox data info data from base path to target dir
-        VoxCeleb1.csv_path = os.path.join(
-            target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb1.csv_path
-        VoxCeleb1.meta_path = os.path.join(
+        VoxCeleb.csv_path = os.path.join(
+            target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
+        VoxCeleb.meta_path = os.path.join(
             target_dir, "voxceleb",
-            'meta') if target_dir else VoxCeleb1.meta_path
-        VoxCeleb1.veri_test_file = os.path.join(VoxCeleb1.meta_path,
-                                                'veri_test2.txt')
+            'meta') if target_dir else VoxCeleb.meta_path
+        VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
+                                               'veri_test2.txt')
         # self._data = self._get_data()[:1000]  # KP: Small dataset test.
         self._data = self._get_data()
-        super(VoxCeleb1, self).__init__()
+        super(VoxCeleb, self).__init__()
 
         # Set up a seed to reproduce training or predicting result.
         # random.seed(seed)
 
@@ -300,7 +311,14 @@ class VoxCeleb1(Dataset):
         # get all the train and dev audios file path
         audio_files = []
         speakers = set()
-        for path in [self.wav_path]:
+        for path in [self.wav_path, self.vox2_base_path]:
+            # if the vox2 base path is not set or the path does not exist,
+            # we skip it and process only the valid directories
+            if not path or not os.path.exists(path):
+                logger.warning(
+                    f"{path} is an invalid path, please check it; "
+                    "we will ignore the vox2 base path")
+                continue
             for file in glob.glob(
                     os.path.join(path, "**", "*.wav"), recursive=True):
                 spk = file.split('/wav/')[1].split('/')[0]
diff --git a/paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py
similarity index 99%
rename from paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py
rename to paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py
index 78498c6132e21439c05aaa96fe416dd740589c59..44cbd204f9955079b853f9b2b804feeee23ce10e 100644
--- a/paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/extract_speaker_embedding.py
@@ -28,6 +28,7 @@ from paddlespeech.vector.training.seeding import seed_everything
 
 logger = Log(__name__).getlog()
 
+
 def extract_audio_embedding(args, config):
     # stage 0: set the training device, cpu or gpu
     paddle.set_device(args.device)
@@ -83,7 +84,7 @@ if __name__ == "__main__":
                         choices=['cpu', 'gpu'],
                         default="gpu",
                         help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--config", 
+    parser.add_argument("--config",
                         default=None,
                         type=str,
                         help="configuration file")
diff --git a/paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py b/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py
similarity index 96%
rename from paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py
rename to paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py
index 4d85bd62203614f876b965db7acff1b3c6b950d8..01a3506a2a554fc45b853cefbe03f2a7ed41516f 100644
--- a/paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/speaker_verification_cosine.py
@@ -17,15 +17,15 @@ import os
 
 import numpy as np
 import paddle
-from yacs.config import CfgNode
 import paddle.nn.functional as F
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from tqdm import tqdm
+from yacs.config import CfgNode
 
-from paddleaudio.paddleaudio.datasets import VoxCeleb1
-from paddlespeech.s2t.utils.log import Log
+from paddleaudio.paddleaudio.datasets import VoxCeleb
 from paddleaudio.paddleaudio.metric import compute_eer
+from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.batch import batch_feature_normalize
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
@@ -33,6 +33,7 @@ from paddlespeech.vector.training.seeding import seed_everything
 
 logger = Log(__name__).getlog()
 
+
 def main(args, config):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
@@ -44,7 +45,7 @@ def main(args, config):
 
     # stage2: build the speaker verification eval instance with backbone model
     model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
+        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
 
     # stage3: load the pre-trained model
     args.load_checkpoint = os.path.abspath(
@@ -57,7 +58,7 @@ def main(args, config):
     logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
 
     # stage4: construct the enroll and test dataloader
-    enroll_dataset = VoxCeleb1(
+    enroll_dataset = VoxCeleb(
         subset='enroll',
         target_dir=args.data_dir,
         feat_type='melspectrogram',
@@ -73,7 +74,7 @@ def main(args, config):
         num_workers=config.num_workers,
         return_list=True,)
 
-    test_dataset = VoxCeleb1(
+    test_dataset = VoxCeleb(
         subset='test',
         target_dir=args.data_dir,
         feat_type='melspectrogram',
@@ -145,7 +146,7 @@ def main(args, config):
     labels = []
    enrol_ids = []
     test_ids = []
-    with open(VoxCeleb1.veri_test_file, 'r') as f:
+    with open(VoxCeleb.veri_test_file, 'r') as f:
         for line in f.readlines():
             label, enrol_id, test_id = line.strip().split(' ')
             labels.append(int(label))
@@ -171,7 +172,7 @@ if __name__ == "__main__":
                         choices=['cpu', 'gpu'],
                         default="gpu",
                         help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--config", 
+    parser.add_argument("--config",
                         default=None,
                         type=str,
                         help="configuration file")
diff --git a/paddlespeech/vector/exps/ecapa-tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
similarity index 97%
rename from paddlespeech/vector/exps/ecapa-tdnn/train.py
rename to paddlespeech/vector/exps/ecapa_tdnn/train.py
index 08a4ac1cf04811d43731cb9290cc76811b73bcde..6e6e5ab2417f81c07005581acde16bbda512718e 100644
--- a/paddlespeech/vector/exps/ecapa-tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -20,8 +20,9 @@ from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from yacs.config import CfgNode
+
 from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
-from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
+from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.augment import build_augment_pipeline
 from paddlespeech.vector.io.augment import waveform_augment
@@ -30,13 +31,14 @@ from paddlespeech.vector.io.batch import waveform_collate_fn
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.loss import AdditiveAngularMargin
 from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
-from paddlespeech.vector.training.scheduler import CyclicLRScheduler
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.scheduler import CyclicLRScheduler
 from paddlespeech.vector.training.seeding import seed_everything
 from paddlespeech.vector.utils.time import Timer
 
 logger = Log(__name__).getlog()
 
+
 def main(args, config):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
@@ -50,8 +52,8 @@ def main(args, config):
 
     # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
     # note: some cmd must do in rank==0, so wo will refactor the data prepare code
-    train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
-    dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
+    train_dataset = VoxCeleb('train', target_dir=args.data_dir)
+    dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
 
     if args.augment:
         augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@@ -63,7 +65,7 @@ def main(args, config):
 
     # stage4: build the speaker verification train instance with backbone model
     model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
+        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
 
     # stage5: build the optimizer, we now only construct the AdamW optimizer
     lr_schedule = CyclicLRScheduler(
@@ -263,7 +265,7 @@ if __name__ == "__main__":
                         choices=['cpu', 'gpu'],
                         default="cpu",
                         help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--config", 
+    parser.add_argument("--config",
                         default=None,
                         type=str,
                         help="configuration file")
diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py
index 7631297801d266ec265b4e6f5fc40d2c37058638..1b9d1fbd897e6b7dbc52378f506b4dc92d34336d 100644
--- a/paddlespeech/vector/io/augment.py
+++ b/paddlespeech/vector/io/augment.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# this is modified from https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
 import math
 import os
 from typing import List
diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py
index 4c960e117f5d952a214631d58b4c9a023af33c26..0e7287cd3614d8964941f6d14179e0ce7f3c4d71 100644
--- a/paddlespeech/vector/models/ecapa_tdnn.py
+++ b/paddlespeech/vector/models/ecapa_tdnn.py
@@ -19,16 +19,6 @@ import paddle.nn.functional as F
 
 
 def length_to_mask(length, max_len=None, dtype=None):
-    """_summary_
-
-    Args:
-        length (_type_): _description_
-        max_len (_type_, optional): _description_. Defaults to None.
-        dtype (_type_, optional): _description_. Defaults to None.
-
-    Returns:
-        _type_: _description_
-    """
     assert len(length.shape) == 1
 
     if max_len is None:
@@ -60,15 +50,15 @@ class Conv1d(nn.Layer):
         """_summary_
 
         Args:
-            in_channels (_type_): _description_
-            out_channels (_type_): _description_
-            kernel_size (_type_): _description_
-            stride (int, optional): _description_. Defaults to 1.
-            padding (str, optional): _description_. Defaults to "same".
-            dilation (int, optional): _description_. Defaults to 1.
-            groups (int, optional): _description_. Defaults to 1.
-            bias (bool, optional): _description_. Defaults to True.
-            padding_mode (str, optional): _description_. Defaults to "reflect".
+            in_channels (int): input channels or input data dimensions
+            out_channels (int): output channels or output data dimensions
+            kernel_size (int): kernel size of the 1-d convolution
+            stride (int, optional): stride of the 1-d convolution. Defaults to 1.
+            padding (str, optional): padding value. Defaults to "same".
+            dilation (int, optional): dilation of the 1-d convolution. Defaults to 1.
+            groups (int, optional): groups of the 1-d convolution. Defaults to 1.
+            bias (bool, optional): bias of the 1-d convolution. Defaults to True.
+            padding_mode (str, optional): padding mode. Defaults to "reflect".
""" super().__init__() @@ -89,17 +79,6 @@ class Conv1d(nn.Layer): bias_attr=bias, ) def forward(self, x): - """_summary_ - - Args: - x (_type_): _description_ - - Raises: - ValueError: _description_ - - Returns: - _type_: _description_ - """ if self.padding == "same": x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) @@ -109,17 +88,6 @@ class Conv1d(nn.Layer): return self.conv(x) def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): - """_summary_ - - Args: - x (_type_): _description_ - kernel_size (int): _description_ - dilation (int): _description_ - stride (int): _description_ - - Returns: - _type_: _description_ - """ L_in = x.shape[-1] # Detecting input shape padding = self._get_padding_elem(L_in, stride, kernel_size, dilation) # Time padding @@ -133,17 +101,6 @@ class Conv1d(nn.Layer): stride: int, kernel_size: int, dilation: int): - """_summary_ - - Args: - L_in (int): _description_ - stride (int): _description_ - kernel_size (int): _description_ - dilation (int): _description_ - - Returns: - _type_: _description_ - """ if stride > 1: n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) L_out = stride * (n_steps - 1) + kernel_size * dilation @@ -220,8 +177,8 @@ class Res2NetBlock(nn.Layer): Args: in_channels (int): input channels or input dimensions out_channels (int): output channels or output dimensions - scale (int, optional): _description_. Defaults to 8. - dilation (int, optional): _description_. Defaults to 1. + scale (int, optional): scale in res2net bolck. Defaults to 8. + dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1. """ super().__init__() assert in_channels % scale == 0 @@ -358,15 +315,16 @@ class SERes2NetBlock(nn.Layer): dilation=1, activation=nn.ReLU, ): """Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model - + The paper is refered "Squeeze-and-Excitation Networks" + whose url is: https://arxiv.org/pdf/1709.01507.pdf Args: in_channels (int): input channels or input data dimensions - out_channels (_type_): _description_ - res2net_scale (int, optional): _description_. Defaults to 8. - se_channels (int, optional): _description_. Defaults to 128. - kernel_size (int, optional): _description_. Defaults to 1. - dilation (int, optional): _description_. Defaults to 1. - activation (_type_, optional): _description_. Defaults to nn.ReLU. + out_channels (int): output channels or output data dimensions + res2net_scale (int, optional): scale in the res2net block. Defaults to 8. + se_channels (int, optional): embedding dimensions of res2net block. Defaults to 128. + kernel_size (int, optional): kernel size of 1-d convolution in TDNN block. Defaults to 1. + dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1. + activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU. """ super().__init__() self.out_channels = out_channels @@ -419,7 +377,21 @@ class EcapaTdnn(nn.Layer): res2net_scale=8, se_channels=128, global_context=True, ): - + """Implementation of ECAPA-TDNN backbone model network + The paper is refered as "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification" + whose url is: https://arxiv.org/abs/2005.07143 + Args: + input_size (_type_): input fature dimension + lin_neurons (int, optional): speaker embedding size. Defaults to 192. + activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU. 
+            channels (list, optional): inter embedding dimensions. Defaults to [512, 512, 512, 512, 1536].
+            kernel_sizes (list, optional): kernel sizes of the 1-d convolutions in TDNN blocks. Defaults to [5, 3, 3, 3, 1].
+            dilations (list, optional): dilations of the 1-d convolutions in TDNN blocks. Defaults to [1, 2, 3, 4, 1].
+            attention_channels (int, optional): attention dimensions. Defaults to 128.
+            res2net_scale (int, optional): scale value in res2net. Defaults to 8.
+            se_channels (int, optional): dimensions of the squeeze-excitation block. Defaults to 128.
+            global_context (bool, optional): global context flag. Defaults to True.
+        """
         super().__init__()
         assert len(channels) == len(kernel_sizes)
         assert len(channels) == len(dilations)
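
---

For reviewers, here is a minimal sketch of how the renamed pieces in this patch fit together. The calls mirror `local/data_prepare.py` and `train.py` above; the paths and the `input_size` value are illustrative placeholders, not something this diff pins down.

``` python
# Sketch only: exercises the API surface touched by this diff.
import paddle

from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification

# Data preparation runs on CPU (see local/data_prepare.py).
paddle.set_device("cpu")

# vox2_base_path must point at wav files already converted from m4a,
# e.g. ${PPAUDIO_HOME}/datasets/vox2/wav/ as documented in run.sh.
train_dataset = VoxCeleb(
    'train', target_dir='./data/', vox2_base_path='./datasets/vox2/wav/')
dev_dataset = VoxCeleb(
    'dev', target_dir='./data/', vox2_base_path='./datasets/vox2/wav/')

# Build the backbone and wrap it for speaker identification, as in train.py.
# input_size=80 assumes 80-dimensional melspectrogram features (an assumption).
ecapa_tdnn = EcapaTdnn(input_size=80)
model = SpeakerIdetification(
    backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
```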
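Similarly, a hedged sketch of the trial-scoring loop that `speaker_verification_cosine.py` implements: the `veri_test2.txt` trial format (`label enroll_id test_id`) and the `compute_eer` import come from this diff, while the toy embeddings and the assumed `compute_eer(labels, scores)` argument order stand in for the real model outputs and call site.

``` python
import numpy as np
import paddle
import paddle.nn.functional as F

from paddleaudio.paddleaudio.metric import compute_eer

# Toy stand-ins for embeddings extracted by the model
# (192 matches the lin_neurons default documented above).
embeddings = {
    "enroll-utt": paddle.randn([1, 192]),
    "test-utt": paddle.randn([1, 192]),
    "other-utt": paddle.randn([1, 192]),
}

# Each trial line in veri_test2.txt reads: "<label> <enroll_id> <test_id>".
trials = [("1", "enroll-utt", "test-utt"), ("0", "enroll-utt", "other-utt")]

labels, scores = [], []
for label, enroll_id, test_id in trials:
    # Cosine similarity between the enroll and test embeddings.
    score = F.cosine_similarity(embeddings[enroll_id], embeddings[test_id])
    labels.append(int(label))
    scores.append(float(score))

# Assumed signature: compute_eer(labels, scores) -> (eer, threshold).
eer, threshold = compute_eer(np.asarray(labels), np.asarray(scores))
print(f"EER: {eer:.4f} at threshold {threshold:.4f}")
```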