diff --git a/examples/voxceleb/sv0/local/speaker_verification_cosine.py b/examples/voxceleb/sv0/local/speaker_verification_cosine.py
index 1959e85c5124f2f06af64a37d7115224b8b050fb..b0adcf6649dd4e7b95c3c9a8cc75ef835509648f 100644
--- a/examples/voxceleb/sv0/local/speaker_verification_cosine.py
+++ b/examples/voxceleb/sv0/local/speaker_verification_cosine.py
@@ -23,9 +23,13 @@ from paddle.io import DataLoader
 from tqdm import tqdm
 
 from paddleaudio.datasets.voxceleb import VoxCeleb1
+from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
 from paddlespeech.vector.training.metrics import compute_eer
+from paddlespeech.vector.training.seeding import seed_everything
+
+logger = Log(__name__).getlog()
 
 
 def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
@@ -67,9 +71,19 @@ def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
     return {'ids': ids, 'feats': feats, 'lengths': lengths}
 
 
+# feat configuration
+cpu_feat_conf = {
+    'n_mels': 80,
+    'window_size': 400,  # samples, i.e. 25ms at a 16kHz sample rate
+    'hop_length': 160,  # samples, i.e. 10ms at a 16kHz sample rate
+}
+
+
 def main(args):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
+    # set the random seed; it is required for multi-process training
+    seed_everything(args.seed)
 
     # stage1: build the dnn backbone model network
     ##"channels": [1024, 1024, 1024, 1024, 3072],
@@ -95,19 +109,18 @@ def main(args):
         state_dict = paddle.load(
             os.path.join(args.load_checkpoint, 'model.pdparams'))
         model.set_state_dict(state_dict)
-        print(f'Checkpoint loaded from {args.load_checkpoint}')
+        logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
 
     # stage4: construct the enroll and test dataloader
     enrol_ds = VoxCeleb1(
         subset='enrol',
+        target_dir=args.data_dir,
         feat_type='melspectrogram',
         random_chunk=False,
-        n_mels=80,
-        window_size=400,
-        hop_length=160)
     enrol_sampler = BatchSampler(
         enrol_ds, batch_size=args.batch_size,
-        shuffle=True)  # Shuffle to make embedding normalization more robust.
+        **cpu_feat_conf)
+    enrol_sampler = BatchSampler(
+        enrol_ds, batch_size=args.batch_size,
+        shuffle=False)  # No shuffling is needed when extracting embeddings.
     enrol_loader = DataLoader(enrol_ds,
                               batch_sampler=enrol_sampler,
                               collate_fn=lambda x: feature_normalize(
@@ -117,14 +130,13 @@ def main(args):
 
     test_ds = VoxCeleb1(
         subset='test',
+        target_dir=args.data_dir,
         feat_type='melspectrogram',
         random_chunk=False,
-        n_mels=80,
-        window_size=400,
-        hop_length=160)
+        **cpu_feat_conf)
 
     test_sampler = BatchSampler(
-        test_ds, batch_size=args.batch_size, shuffle=True)
+        test_ds, batch_size=args.batch_size, shuffle=False)
     test_loader = DataLoader(test_ds,
                              batch_sampler=test_sampler,
                              collate_fn=lambda x: feature_normalize(
@@ -136,10 +148,10 @@ def main(args):
 
     # stage7: global embedding norm to improve the performance
     if args.global_embedding_norm:
-        embedding_mean = None
-        embedding_std = None
-        mean_norm = args.embedding_mean_norm
-        std_norm = args.embedding_std_norm
+        global_embedding_mean = None
+        global_embedding_std = None
+        mean_norm_flag = args.embedding_mean_norm
+        std_norm_flag = args.embedding_std_norm
         batch_count = 0
 
     # stage8: Compute embeddings of audios in enrol and test dataset from model.
@@ -147,7 +159,7 @@ def main(args):
     # Run multiple times to make embedding normalization more stable.
     for i in range(2):
         for dl in [enrol_loader, test_loader]:
-            print(
+            logger.info(
                 f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
             )
             with paddle.no_grad():
@@ -162,20 +174,24 @@ def main(args):
                     # Global embedding normalization.
                    if args.global_embedding_norm:
                        batch_count += 1
-                        mean = embeddings.mean(axis=0) if mean_norm else 0
-                        std = embeddings.std(axis=0) if std_norm else 1
+                        current_mean = embeddings.mean(
+                            axis=0) if mean_norm_flag else 0
+                        current_std = embeddings.std(
+                            axis=0) if std_norm_flag else 1
                        # Update global mean and std.
-                        if embedding_mean is None and embedding_std is None:
-                            embedding_mean, embedding_std = mean, std
+                        if global_embedding_mean is None and global_embedding_std is None:
+                            global_embedding_mean, global_embedding_std = current_mean, current_std
                        else:
                            weight = 1 / batch_count  # Weight decay by batches.
-                            embedding_mean = (1 - weight
-                                              ) * embedding_mean + weight * mean
-                            embedding_std = (1 - weight
-                                             ) * embedding_std + weight * std
+                            global_embedding_mean = (
+                                1 - weight
+                            ) * global_embedding_mean + weight * current_mean
+                            global_embedding_std = (
+                                1 - weight
+                            ) * global_embedding_std + weight * current_std
                        # Apply global embedding normalization.
-                        embeddings = (
-                            embeddings - embedding_mean) / embedding_std
+                        embeddings = (embeddings - global_embedding_mean
+                                      ) / global_embedding_std
 
                    # Update embedding dict.
                    id2embedding.update(dict(zip(ids, embeddings)))
@@ -198,7 +214,7 @@ def main(args):
     ])  # (N, emb_size)
     scores = cos_sim_func(enrol_embeddings, test_embeddings)
     EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
-    print(
+    logger.info(
         f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
     )
 
@@ -210,10 +226,18 @@ if __name__ == "__main__":
                         choices=['cpu', 'gpu'],
                         default="gpu",
                         help="Select which device to train model, defaults to gpu.")
+    parser.add_argument("--seed",
+                        default=0,
+                        type=int,
+                        help="random seed for paddle, numpy and the python random package")
+    parser.add_argument("--data-dir",
+                        default="./data/",
+                        type=str,
+                        help="data directory")
     parser.add_argument("--batch-size",
                         type=int,
                         default=16,
-                        help="Total examples' number in batch for training.")
+                        help="Total number of examples in a batch for extracting embeddings.")
     parser.add_argument("--num-workers",
                         type=int,
                         default=0,
diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
index 4eabf94c09f36e7995f1a82812ee19803917e88e..745d5eab3b6fc8dc994dceac6617fc24705aaf60 100644
--- a/examples/voxceleb/sv0/local/train.py
+++ b/examples/voxceleb/sv0/local/train.py
@@ -22,6 +22,9 @@ from paddle.io import DistributedBatchSampler
 
 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddleaudio.features.core import melspectrogram
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.io.augment import build_augment_pipeline
+from paddlespeech.vector.io.augment import waveform_augment
 from paddlespeech.vector.io.batch import feature_normalize
 from paddlespeech.vector.io.batch import waveform_collate_fn
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
@@ -29,8 +32,11 @@ from paddlespeech.vector.modules.loss import AdditiveAngularMargin
 from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
 from paddlespeech.vector.modules.lr import CyclicLRScheduler
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.seeding import seed_everything
 from paddlespeech.vector.utils.time import Timer
 
+logger = Log(__name__).getlog()
+
 # feat configuration
 cpu_feat_conf = {
     'n_mels': 80,
@@ -47,12 +53,19 @@ def main(args):
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
     local_rank = paddle.distributed.get_rank()
+    # set the random seed; it is required for multi-process training
+    seed_everything(args.seed)
 
-    # stage2: data prepare
-    # note: some cmd must do in rank==0
+    # stage2: data preparation, such as the vox1 and vox2 data, and the augmentation pipeline
+    # note: some commands must run on rank==0, so we will refactor the data preparation code
     train_ds = VoxCeleb1('train', target_dir=args.data_dir)
     dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
 
+    if args.augment:
+        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
+    else:
+        augment_pipeline = []
+
     # stage3: build the dnn backbone model network
     #"channels": [1024, 1024, 1024, 1024, 3072],
     model_conf = {
@@ -83,7 +96,7 @@ def main(args):
     # if pre-trained model exists, start epoch confirmed by the pre-trained model
     start_epoch = 0
     if args.load_checkpoint:
-        print("load the check point")
+        logger.info("load the checkpoint")
         args.load_checkpoint = os.path.abspath(
             os.path.expanduser(args.load_checkpoint))
         try:
@@ -97,14 +110,14 @@ def main(args):
                 os.path.join(args.load_checkpoint, 'model.pdopt'))
             optimizer.set_state_dict(state_dict)
             if local_rank == 0:
-                print(f'Checkpoint loaded from {args.load_checkpoint}')
+                logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
         except FileExistsError:
             if local_rank == 0:
-                print('Train from scratch.')
+                logger.info('Train from scratch.')
 
             try:
                 start_epoch = int(args.load_checkpoint[-1])
-                print(f'Restore training from epoch {start_epoch}.')
+                logger.info(f'Restore training from epoch {start_epoch}.')
             except ValueError:
                 pass
 
@@ -137,7 +150,10 @@ def main(args):
             waveforms, labels = batch['waveforms'], batch['labels']
 
             # stage 9-2: audio sample augment method, which is done on the audio sample point
-            # todo
+            if len(augment_pipeline) != 0:
+                waveforms = waveform_augment(waveforms, augment_pipeline)
+                labels = paddle.concat(
+                    [labels for i in range(len(augment_pipeline) + 1)])
 
             # stage 9-3: extract the audio feats, such as fbank, mfcc, spectrogram
             feats = []
@@ -185,7 +201,7 @@ def main(args):
                     print_msg += ' acc={:.4f}'.format(avg_acc)
                     print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
                         lr, timer.timing, timer.eta)
-                    print(print_msg)
+                    logger.info(print_msg)
 
                     avg_loss = 0
                     num_corrects = 0
@@ -217,7 +233,7 @@ def main(args):
                num_samples = 0
 
                # stage 9-13: evaluate the valid dataset batch data
-                print('Evaluate on validation dataset')
+                logger.info('Evaluate on validation dataset')
                with paddle.no_grad():
                    for batch_idx, batch in enumerate(dev_loader):
                        waveforms, labels = batch['waveforms'], batch['labels']
@@ -238,12 +254,12 @@ def main(args):
 
                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
-                print(print_msg)
+                logger.info(print_msg)
 
                # stage 9-14: Save model parameters
                save_dir = os.path.join(args.checkpoint_dir,
                                        'epoch_{}'.format(epoch))
-                print('Saving model checkpoint to {}'.format(save_dir))
+                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
@@ -260,6 +276,10 @@ if __name__ == "__main__":
                         choices=['cpu', 'gpu'],
                         default="cpu",
                         help="Select which device to train model, defaults to gpu.")
+    parser.add_argument("--seed",
+                        default=0,
+                        type=int,
+                        help="random seed for paddle, numpy and the python random package")
     parser.add_argument("--data-dir",
                         default="./data/",
                         type=str,
@@ -295,6 +315,10 @@ if __name__ == "__main__":
                         type=str,
                         default='./checkpoint',
                         help="Directory to save model checkpoints.")
+    parser.add_argument("--augment",
+                        action="store_true",
+                        default=False,
help="Apply audio augments.") args = parser.parse_args() # yapf: enable diff --git a/paddleaudio/datasets/rirs_noises.py b/paddleaudio/datasets/rirs_noises.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9e7f09d104282f12d265227f6ee28b19046cc7 --- /dev/null +++ b/paddleaudio/datasets/rirs_noises.py @@ -0,0 +1,207 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import csv +import glob +import os +import random +from typing import Dict +from typing import List +from typing import Tuple + +from paddle.io import Dataset +from tqdm import tqdm + +from paddleaudio.backends import load as load_audio +from paddleaudio.backends import save_wav +from paddleaudio.datasets.dataset import feat_funcs +from paddleaudio.utils import DATA_HOME +from paddleaudio.utils import decompress +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.utils.download import download_and_decompress + +logger = Log(__name__).getlog() + +__all__ = ['OpenRIRNoise'] + + +class OpenRIRNoise(Dataset): + archieves = [ + { + 'url': 'http://www.openslr.org/resources/28/rirs_noises.zip', + 'md5': 'e6f48e257286e05de56413b4779d8ffb', + }, + ] + + sample_rate = 16000 + meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav')) + base_path = os.path.join(DATA_HOME, 'open_rir_noise') + wav_path = os.path.join(base_path, 'RIRS_NOISES') + csv_path = os.path.join(base_path, 'csv') + subsets = ['rir', 'noise'] + + def __init__(self, + subset: str='rir', + feat_type: str='raw', + target_dir=None, + random_chunk: bool=True, + chunk_duration: float=3.0, + seed: int=0, + **kwargs): + + assert subset in self.subsets, \ + 'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset) + + self.subset = subset + self.feat_type = feat_type + self.feat_config = kwargs + self.random_chunk = random_chunk + self.chunk_duration = chunk_duration + + self.csv_path = os.path.join(target_dir, "open_rir_noise", + "csv") if target_dir else self.csv_path + self._data = self._get_data() + super(OpenRIRNoise, self).__init__() + + # Set up a seed to reproduce training or predicting result. + # random.seed(seed) + + def _get_data(self): + # Download audio files. + logger.info(f"rirs noises base path: {self.base_path}") + if not os.path.isdir(self.base_path): + download_and_decompress( + self.archieves, self.base_path, decompress=True) + else: + logger.info( + f"{self.base_path} already exists, we will not download and decompress again" + ) + + # Data preparation. 
+ logger.info(f"prepare the csv to {self.csv_path}") + if not os.path.isdir(self.csv_path): + os.makedirs(self.csv_path) + self.prepare_data() + + data = [] + with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf: + for line in rf.readlines()[1:]: + audio_id, duration, wav = line.strip().split(',') + data.append(self.meta_info(audio_id, float(duration), wav)) + + random.shuffle(data) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in type(sample)._fields: + record[field] = getattr(sample, field) + + waveform, sr = load_audio(record['wav']) + + assert self.feat_type in feat_funcs.keys(), \ + f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + return record + + @staticmethod + def _get_chunks(seg_dur, audio_id, audio_duration): + num_chunks = int(audio_duration / seg_dur) # all in milliseconds + + chunk_lst = [ + audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur) + for i in range(num_chunks) + ] + return chunk_lst + + def _get_audio_info(self, wav_file: str, + split_chunks: bool) -> List[List[str]]: + waveform, sr = load_audio(wav_file) + audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0] + audio_duration = waveform.shape[0] / sr + + ret = [] + if split_chunks and audio_duration > self.chunk_duration: # Split into pieces of self.chunk_duration seconds. + uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id, + audio_duration) + + for idx, chunk in enumerate(uniq_chunks_list): + s, e = chunk.split("_")[-2:] # Timestamps of start and end + start_sample = int(float(s) * sr) + end_sample = int(float(e) * sr) + new_wav_file = os.path.join(self.base_path, + audio_id + f'_chunk_{idx+1:02}.wav') + save_wav(waveform[start_sample:end_sample], sr, new_wav_file) + # id, duration, new_wav + ret.append([chunk, self.chunk_duration, new_wav_file]) + else: # Keep whole audio. 
+            ret.append([audio_id, audio_duration, wav_file])
+        return ret
+
+    def generate_csv(self,
+                     wav_files: List[str],
+                     output_file: str,
+                     split_chunks: bool=True):
+        logger.info(f'Generating csv: {output_file}')
+        header = ["id", "duration", "wav"]
+
+        infos = list(
+            tqdm(
+                map(self._get_audio_info, wav_files, [split_chunks] * len(
+                    wav_files)),
+                total=len(wav_files)))
+
+        csv_lines = []
+        for info in infos:
+            csv_lines.extend(info)
+
+        with open(output_file, mode="w") as csv_f:
+            csv_writer = csv.writer(
+                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
+            csv_writer.writerow(header)
+            for line in csv_lines:
+                csv_writer.writerow(line)
+
+    def prepare_data(self):
+        rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises",
+                                "rir_list")
+        rir_files = []
+        with open(rir_list, 'r') as f:
+            for line in f.readlines():
+                rir_file = line.strip().split(' ')[-1]
+                rir_files.append(os.path.join(self.base_path, rir_file))
+
+        noise_list = os.path.join(self.wav_path, "pointsource_noises",
+                                  "noise_list")
+        noise_files = []
+        with open(noise_list, 'r') as f:
+            for line in f.readlines():
+                noise_file = line.strip().split(' ')[-1]
+                noise_files.append(os.path.join(self.base_path, noise_file))
+
+        self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
+        self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))
+
+    def __getitem__(self, idx):
+        return self._convert_to_record(idx)
+
+    def __len__(self):
+        return len(self._data)
diff --git a/paddleaudio/datasets/voxceleb.py b/paddleaudio/datasets/voxceleb.py
index 760db72169f0ff16f51d52372dc6d7d618b760c8..28f6dfc66b9ddd4947e0addfb501f962b7bb4fbc 100644
--- a/paddleaudio/datasets/voxceleb.py
+++ b/paddleaudio/datasets/voxceleb.py
@@ -29,9 +29,12 @@ from paddleaudio.datasets.dataset import feat_funcs
 from paddleaudio.utils import DATA_HOME
 from paddleaudio.utils import decompress
 from paddleaudio.utils import download_and_decompress
+from paddlespeech.s2t.utils.log import Log
 from utils.utility import download
 from utils.utility import unpack
 
+logger = Log(__name__).getlog()
+
 __all__ = ['VoxCeleb1']
 
 
@@ -121,9 +124,9 @@ class VoxCeleb1(Dataset):
         # Download audio files.
         # We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
         # so, we check the vox1/wav dir status
-        print("wav base path: {}".format(self.wav_path))
+        logger.info(f"wav base path: {self.wav_path}")
         if not os.path.isdir(self.wav_path):
-            print("start to download the voxceleb1 dataset")
+            logger.info("start to download the voxceleb1 dataset")
             download_and_decompress(  # multi-zip parts concatenate to vox1_dev_wav.zip
                 self.archieves_audio_dev,
                 self.base_path,
@@ -135,7 +138,7 @@ class VoxCeleb1(Dataset):
 
             # Download all parts and concatenate the files into one zip file.
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip') - print(f'Concatenating all parts to: {dev_zipfile}') + logger.info(f'Concatenating all parts to: {dev_zipfile}') os.system( f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}' ) @@ -154,6 +157,9 @@ class VoxCeleb1(Dataset): self.prepare_data() data = [] + logger.info( + f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}" + ) with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf: for line in rf.readlines()[1:]: audio_id, duration, wav, start, stop, spk_id = line.strip( @@ -246,7 +252,7 @@ class VoxCeleb1(Dataset): wav_files: List[str], output_file: str, split_chunks: bool=True): - print(f'Generating csv: {output_file}') + logger.info(f'Generating csv: {output_file}') header = ["id", "duration", "wav", "start", "stop", "spk_id"] with Pool(64) as p: @@ -269,7 +275,7 @@ class VoxCeleb1(Dataset): def prepare_data(self): # Audio of speakers in veri_test_file should not be included in training set. - print("start to prepare the data csv file") + logger.info("start to prepare the data csv file") enrol_files = set() test_files = set() # get the enroll and test audio file path @@ -299,7 +305,7 @@ class VoxCeleb1(Dataset): speakers.add(spk) audio_files.append(file) - print("start to generate the {}".format( + logger.info("start to generate the {}".format( os.path.join(self.meta_path, 'spk_id2label.txt'))) # encode the train and dev speakers label to spk_id2label.txt with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f: diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bbc8a99cbcf3b7570f61693b7dce1c4ed40001 --- /dev/null +++ b/paddlespeech/vector/io/augment.py @@ -0,0 +1,899 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +from typing import List + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleaudio.backends import load as load_audio +from paddleaudio.datasets.rirs_noises import OpenRIRNoise +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.signal_processing import compute_amplitude +from paddlespeech.vector.io.signal_processing import convolve1d +from paddlespeech.vector.io.signal_processing import dB_to_amplitude +from paddlespeech.vector.io.signal_processing import notch_filter +from paddlespeech.vector.io.signal_processing import reverberate + +logger = Log(__name__).getlog() + + +# TODO: Complete type-hint and doc string. 
+class DropFreq(nn.Layer): + def __init__( + self, + drop_freq_low=1e-14, + drop_freq_high=1, + drop_count_low=1, + drop_count_high=2, + drop_width=0.05, + drop_prob=1, ): + super(DropFreq, self).__init__() + self.drop_freq_low = drop_freq_low + self.drop_freq_high = drop_freq_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_width = drop_width + self.drop_prob = drop_prob + + def forward(self, waveforms): + # Don't drop (return early) 1-`drop_prob` portion of the batches + dropped_waveform = waveforms.clone() + if paddle.rand([1]) > self.drop_prob: + return dropped_waveform + + # Add channels dimension + if len(waveforms.shape) == 2: + dropped_waveform = dropped_waveform.unsqueeze(-1) + + # Pick number of frequencies to drop + drop_count = paddle.randint( + low=self.drop_count_low, high=self.drop_count_high + 1, shape=[1]) + + # Pick a frequency to drop + drop_range = self.drop_freq_high - self.drop_freq_low + drop_frequency = ( + paddle.rand([drop_count]) * drop_range + self.drop_freq_low) + + # Filter parameters + filter_length = 101 + pad = filter_length // 2 + + # Start with delta function + drop_filter = paddle.zeros([1, filter_length, 1]) + drop_filter[0, pad, 0] = 1 + + # Subtract each frequency + for frequency in drop_frequency: + notch_kernel = notch_filter(frequency, filter_length, + self.drop_width) + drop_filter = convolve1d(drop_filter, notch_kernel, pad) + + # Apply filter + dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad) + + # Remove channels dimension if added + return dropped_waveform.squeeze(-1) + + +class DropChunk(nn.Layer): + def __init__( + self, + drop_length_low=100, + drop_length_high=1000, + drop_count_low=1, + drop_count_high=10, + drop_start=0, + drop_end=None, + drop_prob=1, + noise_factor=0.0, ): + super(DropChunk, self).__init__() + self.drop_length_low = drop_length_low + self.drop_length_high = drop_length_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_start = drop_start + self.drop_end = drop_end + self.drop_prob = drop_prob + self.noise_factor = noise_factor + + # Validate low < high + if drop_length_low > drop_length_high: + raise ValueError("Low limit must not be more than high limit") + if drop_count_low > drop_count_high: + raise ValueError("Low limit must not be more than high limit") + + # Make sure the length doesn't exceed end - start + if drop_end is not None and drop_end >= 0: + if drop_start > drop_end: + raise ValueError("Low limit must not be more than high limit") + + drop_range = drop_end - drop_start + self.drop_length_low = min(drop_length_low, drop_range) + self.drop_length_high = min(drop_length_high, drop_range) + + def forward(self, waveforms, lengths): + # Reading input list + lengths = (lengths * waveforms.shape[1]).astype('int64') + batch_size = waveforms.shape[0] + dropped_waveform = waveforms.clone() + + # Don't drop (return early) 1-`drop_prob` portion of the batches + if paddle.rand([1]) > self.drop_prob: + return dropped_waveform + + # Store original amplitude for computing white noise amplitude + clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1)) + + # Pick a number of times to drop + drop_times = paddle.randint( + low=self.drop_count_low, + high=self.drop_count_high + 1, + shape=[batch_size], ) + + # Iterate batch to set mask + for i in range(batch_size): + if drop_times[i] == 0: + continue + + # Pick lengths + length = paddle.randint( + low=self.drop_length_low, + high=self.drop_length_high + 
1, + shape=[drop_times[i]], ) + + # Compute range of starting locations + start_min = self.drop_start + if start_min < 0: + start_min += lengths[i] + start_max = self.drop_end + if start_max is None: + start_max = lengths[i] + if start_max < 0: + start_max += lengths[i] + start_max = max(0, start_max - length.max()) + + # Pick starting locations + start = paddle.randint( + low=start_min, + high=start_max + 1, + shape=[drop_times[i]], ) + + end = start + length + + # Update waveform + if not self.noise_factor: + for j in range(drop_times[i]): + dropped_waveform[i, start[j]:end[j]] = 0.0 + else: + # Uniform distribution of -2 to +2 * avg amplitude should + # preserve the average for normalization + noise_max = 2 * clean_amplitude[i] * self.noise_factor + for j in range(drop_times[i]): + # zero-center the noise distribution + noise_vec = paddle.rand([length[j]], dtype='float32') + + noise_vec = 2 * noise_max * noise_vec - noise_max + dropped_waveform[i, int(start[j]):int(end[j])] = noise_vec + + return dropped_waveform + + +class Resample(nn.Layer): + def __init__( + self, + orig_freq=16000, + new_freq=16000, + lowpass_filter_width=6, ): + super(Resample, self).__init__() + self.orig_freq = orig_freq + self.new_freq = new_freq + self.lowpass_filter_width = lowpass_filter_width + + # Compute rate for striding + self._compute_strides() + assert self.orig_freq % self.conv_stride == 0 + assert self.new_freq % self.conv_transpose_stride == 0 + + def _compute_strides(self): + # Compute new unit based on ratio of in/out frequencies + base_freq = math.gcd(self.orig_freq, self.new_freq) + input_samples_in_unit = self.orig_freq // base_freq + self.output_samples = self.new_freq // base_freq + + # Store the appropriate stride based on the new units + self.conv_stride = input_samples_in_unit + self.conv_transpose_stride = self.output_samples + + def forward(self, waveforms): + if not hasattr(self, "first_indices"): + self._indices_and_weights(waveforms) + + # Don't do anything if the frequencies are the same + if self.orig_freq == self.new_freq: + return waveforms + + unsqueezed = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(1) + unsqueezed = True + elif len(waveforms.shape) == 3: + waveforms = waveforms.transpose([0, 2, 1]) + else: + raise ValueError("Input must be 2 or 3 dimensions") + + # Do resampling + resampled_waveform = self._perform_resample(waveforms) + + if unsqueezed: + resampled_waveform = resampled_waveform.squeeze(1) + else: + resampled_waveform = resampled_waveform.transpose([0, 2, 1]) + + return resampled_waveform + + def _perform_resample(self, waveforms): + # Compute output size and initialize + batch_size, num_channels, wave_len = waveforms.shape + window_size = self.weights.shape[1] + tot_output_samp = self._output_samples(wave_len) + resampled_waveform = paddle.zeros((batch_size, num_channels, + tot_output_samp)) + + # eye size: (num_channels, num_channels, 1) + eye = paddle.eye(num_channels).unsqueeze(2) + + # Iterate over the phases in the polyphase filter + for i in range(self.first_indices.shape[0]): + wave_to_conv = waveforms + first_index = int(self.first_indices[i].item()) + if first_index >= 0: + # trim the signal as the filter will not be applied + # before the first_index + wave_to_conv = wave_to_conv[:, :, first_index:] + + # pad the right of the signal to allow partial convolutions + # meaning compute values for partial windows (e.g. 
end of the + # window is outside the signal length) + max_index = (tot_output_samp - 1) // self.output_samples + end_index = max_index * self.conv_stride + window_size + current_wave_len = wave_len - first_index + right_padding = max(0, end_index + 1 - current_wave_len) + left_padding = max(0, -first_index) + wave_to_conv = paddle.nn.functional.pad( + wave_to_conv, [left_padding, right_padding], data_format='NCL') + conv_wave = paddle.nn.functional.conv1d( + x=wave_to_conv, + # weight=self.weights[i].repeat(num_channels, 1, 1), + weight=self.weights[i].expand((num_channels, 1, -1)), + stride=self.conv_stride, + groups=num_channels, ) + + # we want conv_wave[:, i] to be at + # output[:, i + n*conv_transpose_stride] + dilated_conv_wave = paddle.nn.functional.conv1d_transpose( + conv_wave, eye, stride=self.conv_transpose_stride) + + # pad dilated_conv_wave so it reaches the output length if needed. + left_padding = i + previous_padding = left_padding + dilated_conv_wave.shape[-1] + right_padding = max(0, tot_output_samp - previous_padding) + dilated_conv_wave = paddle.nn.functional.pad( + dilated_conv_wave, [left_padding, right_padding], + data_format='NCL') + dilated_conv_wave = dilated_conv_wave[:, :, :tot_output_samp] + + resampled_waveform += dilated_conv_wave + + return resampled_waveform + + def _output_samples(self, input_num_samp): + samp_in = int(self.orig_freq) + samp_out = int(self.new_freq) + + tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out) + ticks_per_input_period = tick_freq // samp_in + + # work out the number of ticks in the time interval + # [ 0, input_num_samp/samp_in ). + interval_length = input_num_samp * ticks_per_input_period + if interval_length <= 0: + return 0 + ticks_per_output_period = tick_freq // samp_out + + # Get the last output-sample in the closed interval, + # i.e. replacing [ ) with [ ]. Note: integer division rounds down. + # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an + # explanation of the notation. + last_output_samp = interval_length // ticks_per_output_period + + # We need the last output-sample in the open interval, so if it + # takes us to the end of the interval exactly, subtract one. + if last_output_samp * ticks_per_output_period == interval_length: + last_output_samp -= 1 + + # First output-sample index is zero, so the number of output samples + # is the last output-sample plus one. 
+ num_output_samp = last_output_samp + 1 + + return num_output_samp + + def _indices_and_weights(self, waveforms): + # Lowpass filter frequency depends on smaller of two frequencies + min_freq = min(self.orig_freq, self.new_freq) + lowpass_cutoff = 0.99 * 0.5 * min_freq + + assert lowpass_cutoff * 2 <= min_freq + window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff) + + assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2 + output_t = paddle.arange(start=0.0, end=self.output_samples) + output_t /= self.new_freq + min_t = output_t - window_width + max_t = output_t + window_width + + min_input_index = paddle.ceil(min_t * self.orig_freq) + max_input_index = paddle.floor(max_t * self.orig_freq) + num_indices = max_input_index - min_input_index + 1 + + max_weight_width = num_indices.max() + j = paddle.arange(max_weight_width, dtype='float32') + input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0) + delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1) + + weights = paddle.zeros_like(delta_t) + inside_window_indices = delta_t.abs().less_than( + paddle.to_tensor(window_width)) + + # raised-cosine (Hanning) window with width `window_width` + weights[inside_window_indices] = 0.5 * (1 + paddle.cos( + 2 * math.pi * lowpass_cutoff / self.lowpass_filter_width * + delta_t.masked_select(inside_window_indices))) + + t_eq_zero_indices = delta_t.equal(paddle.zeros_like(delta_t)) + t_not_eq_zero_indices = delta_t.not_equal(paddle.zeros_like(delta_t)) + + # sinc filter function + weights = paddle.where( + t_not_eq_zero_indices, + weights * paddle.sin(2 * math.pi * lowpass_cutoff * delta_t) / + (math.pi * delta_t), weights) + + # limit of the function at t = 0 + weights = paddle.where(t_eq_zero_indices, weights * 2 * lowpass_cutoff, + weights) + + # size (output_samples, max_weight_width) + weights /= self.orig_freq + + self.first_indices = min_input_index + self.weights = weights + + +class SpeedPerturb(nn.Layer): + def __init__( + self, + orig_freq, + speeds=[90, 100, 110], + perturb_prob=1.0, ): + super(SpeedPerturb, self).__init__() + self.orig_freq = orig_freq + self.speeds = speeds + self.perturb_prob = perturb_prob + + # Initialize index of perturbation + self.samp_index = 0 + + # Initialize resamplers + self.resamplers = [] + for speed in self.speeds: + config = { + "orig_freq": self.orig_freq, + "new_freq": self.orig_freq * speed // 100, + } + self.resamplers.append(Resample(**config)) + + def forward(self, waveform): + # Don't perturb (return early) 1-`perturb_prob` portion of the batches + if paddle.rand([1]) > self.perturb_prob: + return waveform.clone() + + # Perform a random perturbation + self.samp_index = paddle.randint(len(self.speeds), shape=[1]).item() + perturbed_waveform = self.resamplers[self.samp_index](waveform) + + return perturbed_waveform + + +class AddNoise(nn.Layer): + def __init__( + self, + noise_dataset=None, # None for white noise + num_workers=0, + snr_low=0, + snr_high=0, + mix_prob=1.0, + start_index=None, + normalize=False, ): + super(AddNoise, self).__init__() + + self.num_workers = num_workers + self.snr_low = snr_low + self.snr_high = snr_high + self.mix_prob = mix_prob + self.start_index = start_index + self.normalize = normalize + self.noise_dataset = noise_dataset + self.noise_dataloader = None + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + # Copy clean waveform to initialize noisy waveform + noisy_waveform = waveforms.clone() + lengths = (lengths * 
waveforms.shape[1]).astype('int64').unsqueeze(1) + + # Don't add noise (return early) 1-`mix_prob` portion of the batches + if paddle.rand([1]) > self.mix_prob: + return noisy_waveform + + # Compute the average amplitude of the clean waveforms + clean_amplitude = compute_amplitude(waveforms, lengths) + + # Pick an SNR and use it to compute the mixture amplitude factors + SNR = paddle.rand((len(waveforms), 1)) + SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low + noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1) + new_noise_amplitude = noise_amplitude_factor * clean_amplitude + + # Scale clean signal appropriately + noisy_waveform *= 1 - noise_amplitude_factor + + # Loop through clean samples and create mixture + if self.noise_dataset is None: + white_noise = paddle.normal(shape=waveforms.shape) + noisy_waveform += new_noise_amplitude * white_noise + else: + tensor_length = waveforms.shape[1] + noise_waveform, noise_length = self._load_noise( + lengths, + tensor_length, ) + + # Rescale and add + noise_amplitude = compute_amplitude(noise_waveform, noise_length) + noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14) + noisy_waveform += noise_waveform + + # Normalizing to prevent clipping + if self.normalize: + abs_max, _ = paddle.max( + paddle.abs(noisy_waveform), axis=1, keepdim=True) + noisy_waveform = noisy_waveform / abs_max.clip(min=1.0) + + return noisy_waveform + + def _load_noise(self, lengths, max_length): + """ + Load a batch of noises + + args + lengths(Paddle.Tensor): Num samples of waveforms with shape (N, 1). + max_length(int): Width of a batch. + """ + lengths = lengths.squeeze(1) + batch_size = len(lengths) + + # Load a noise batch + if self.noise_dataloader is None: + + def noise_collate_fn(batch): + def pad(x, target_length, mode='constant', **kwargs): + x = np.asarray(x) + w = target_length - x.shape[0] + assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' + return np.pad(x, [0, w], mode=mode, **kwargs) + + ids = [item['id'] for item in batch] + lengths = np.asarray([item['feat'].shape[0] for item in batch]) + waveforms = list( + map(lambda x: pad(x, max(max_length, lengths.max().item())), + [item['feat'] for item in batch])) + waveforms = np.stack(waveforms) + return {'ids': ids, 'feats': waveforms, 'lengths': lengths} + + # Create noise data loader. 
+ self.noise_dataloader = paddle.io.DataLoader( + self.noise_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=self.num_workers, + collate_fn=noise_collate_fn, + return_list=True, ) + self.noise_data = iter(self.noise_dataloader) + + noise_batch, noise_len = self._load_noise_batch_of_size(batch_size) + + # Select a random starting location in the waveform + start_index = self.start_index + if self.start_index is None: + start_index = 0 + max_chop = (noise_len - lengths).min().clip(min=1) + start_index = paddle.randint(high=max_chop, shape=[1]) + + # Truncate noise_batch to max_length + noise_batch = noise_batch[:, start_index:start_index + max_length] + noise_len = (noise_len - start_index).clip(max=max_length).unsqueeze(1) + return noise_batch, noise_len + + def _load_noise_batch_of_size(self, batch_size): + """Concatenate noise batches, then chop to correct size""" + noise_batch, noise_lens = self._load_noise_batch() + + # Expand + while len(noise_batch) < batch_size: + noise_batch = paddle.concat((noise_batch, noise_batch)) + noise_lens = paddle.concat((noise_lens, noise_lens)) + + # Contract + if len(noise_batch) > batch_size: + noise_batch = noise_batch[:batch_size] + noise_lens = noise_lens[:batch_size] + + return noise_batch, noise_lens + + def _load_noise_batch(self): + """Load a batch of noises, restarting iteration if necessary.""" + try: + batch = next(self.noise_data) + except StopIteration: + self.noise_data = iter(self.noise_dataloader) + batch = next(self.noise_data) + + noises, lens = batch['feats'], batch['lengths'] + return noises, lens + + +class AddReverb(nn.Layer): + def __init__( + self, + rir_dataset, + reverb_prob=1.0, + rir_scale_factor=1.0, + num_workers=0, ): + super(AddReverb, self).__init__() + self.rir_dataset = rir_dataset + self.reverb_prob = reverb_prob + self.rir_scale_factor = rir_scale_factor + + # Create rir data loader. + def rir_collate_fn(batch): + def pad(x, target_length, mode='constant', **kwargs): + x = np.asarray(x) + w = target_length - x.shape[0] + assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' + return np.pad(x, [0, w], mode=mode, **kwargs) + + ids = [item['id'] for item in batch] + lengths = np.asarray([item['feat'].shape[0] for item in batch]) + waveforms = list( + map(lambda x: pad(x, lengths.max().item()), + [item['feat'] for item in batch])) + waveforms = np.stack(waveforms) + return {'ids': ids, 'feats': waveforms, 'lengths': lengths} + + self.rir_dataloader = paddle.io.DataLoader( + self.rir_dataset, + collate_fn=rir_collate_fn, + num_workers=num_workers, + shuffle=True, + return_list=True, ) + + self.rir_data = iter(self.rir_dataloader) + + def forward(self, waveforms, lengths=None): + """ + Arguments + --------- + waveforms : tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : tensor + Shape should be a single dimension, `[batch]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. 
+ """ + + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + # Don't add reverb (return early) 1-`reverb_prob` portion of the time + if paddle.rand([1]) > self.reverb_prob: + return waveforms.clone() + + # Add channels dimension if necessary + channel_added = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(-1) + channel_added = True + + # Load and prepare RIR + rir_waveform = self._load_rir() + + # Compress or dilate RIR + if self.rir_scale_factor != 1: + rir_waveform = F.interpolate( + rir_waveform.transpose([0, 2, 1]), + scale_factor=self.rir_scale_factor, + mode="linear", + align_corners=False, + data_format='NCW', ) + # (N, C, L) -> (N, L, C) + rir_waveform = rir_waveform.transpose([0, 2, 1]) + + rev_waveform = reverberate( + waveforms, + rir_waveform, + self.rir_dataset.sample_rate, + rescale_amp="avg") + + # Remove channels dimension if added + if channel_added: + return rev_waveform.squeeze(-1) + + return rev_waveform + + def _load_rir(self): + try: + batch = next(self.rir_data) + except StopIteration: + self.rir_data = iter(self.rir_dataloader) + batch = next(self.rir_data) + + rir_waveform = batch['feats'] + + # Make sure RIR has correct channels + if len(rir_waveform.shape) == 2: + rir_waveform = rir_waveform.unsqueeze(-1) + + return rir_waveform + + +class AddBabble(nn.Layer): + def __init__( + self, + speaker_count=3, + snr_low=0, + snr_high=0, + mix_prob=1, ): + super(AddBabble, self).__init__() + self.speaker_count = speaker_count + self.snr_low = snr_low + self.snr_high = snr_high + self.mix_prob = mix_prob + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + babbled_waveform = waveforms.clone() + lengths = (lengths * waveforms.shape[1]).unsqueeze(1) + batch_size = len(waveforms) + + # Don't mix (return early) 1-`mix_prob` portion of the batches + if paddle.rand([1]) > self.mix_prob: + return babbled_waveform + + # Pick an SNR and use it to compute the mixture amplitude factors + clean_amplitude = compute_amplitude(waveforms, lengths) + SNR = paddle.rand((batch_size, 1)) + SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low + noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1) + new_noise_amplitude = noise_amplitude_factor * clean_amplitude + + # Scale clean signal appropriately + babbled_waveform *= 1 - noise_amplitude_factor + + # For each speaker in the mixture, roll and add + babble_waveform = waveforms.roll((1, ), axis=0) + babble_len = lengths.roll((1, ), axis=0) + for i in range(1, self.speaker_count): + babble_waveform += waveforms.roll((1 + i, ), axis=0) + babble_len = paddle.concat( + [babble_len, babble_len.roll((1, ), axis=0)], axis=-1).max( + axis=-1, keepdim=True) + + # Rescale and add to mixture + babble_amplitude = compute_amplitude(babble_waveform, babble_len) + babble_waveform *= new_noise_amplitude / (babble_amplitude + 1e-14) + babbled_waveform += babble_waveform + + return babbled_waveform + + +class TimeDomainSpecAugment(nn.Layer): + def __init__( + self, + perturb_prob=1.0, + drop_freq_prob=1.0, + drop_chunk_prob=1.0, + speeds=[95, 100, 105], + sample_rate=16000, + drop_freq_count_low=0, + drop_freq_count_high=3, + drop_chunk_count_low=0, + drop_chunk_count_high=5, + drop_chunk_length_low=1000, + drop_chunk_length_high=2000, + drop_chunk_noise_factor=0, ): + super(TimeDomainSpecAugment, self).__init__() + self.speed_perturb = SpeedPerturb( + perturb_prob=perturb_prob, + orig_freq=sample_rate, + speeds=speeds, ) + self.drop_freq = DropFreq( + 
+            drop_prob=drop_freq_prob,
+            drop_count_low=drop_freq_count_low,
+            drop_count_high=drop_freq_count_high, )
+        self.drop_chunk = DropChunk(
+            drop_prob=drop_chunk_prob,
+            drop_count_low=drop_chunk_count_low,
+            drop_count_high=drop_chunk_count_high,
+            drop_length_low=drop_chunk_length_low,
+            drop_length_high=drop_chunk_length_high,
+            noise_factor=drop_chunk_noise_factor, )
+
+    def forward(self, waveforms, lengths=None):
+        if lengths is None:
+            lengths = paddle.ones([len(waveforms)])
+
+        with paddle.no_grad():
+            # Augmentation
+            waveforms = self.speed_perturb(waveforms)
+            waveforms = self.drop_freq(waveforms)
+            waveforms = self.drop_chunk(waveforms, lengths)
+
+        return waveforms
+
+
+class EnvCorrupt(nn.Layer):
+    def __init__(
+            self,
+            reverb_prob=1.0,
+            babble_prob=1.0,
+            noise_prob=1.0,
+            rir_dataset=None,
+            noise_dataset=None,
+            num_workers=0,
+            babble_speaker_count=0,
+            babble_snr_low=0,
+            babble_snr_high=0,
+            noise_snr_low=0,
+            noise_snr_high=0,
+            rir_scale_factor=1.0, ):
+        super(EnvCorrupt, self).__init__()
+
+        # Initialize corrupters
+        if rir_dataset is not None and reverb_prob > 0.0:
+            self.add_reverb = AddReverb(
+                rir_dataset=rir_dataset,
+                num_workers=num_workers,
+                reverb_prob=reverb_prob,
+                rir_scale_factor=rir_scale_factor, )
+
+        if babble_speaker_count > 0 and babble_prob > 0.0:
+            self.add_babble = AddBabble(
+                speaker_count=babble_speaker_count,
+                snr_low=babble_snr_low,
+                snr_high=babble_snr_high,
+                mix_prob=babble_prob, )
+
+        if noise_dataset is not None and noise_prob > 0.0:
+            self.add_noise = AddNoise(
+                noise_dataset=noise_dataset,
+                num_workers=num_workers,
+                snr_low=noise_snr_low,
+                snr_high=noise_snr_high,
+                mix_prob=noise_prob, )
+
+    def forward(self, waveforms, lengths=None):
+        if lengths is None:
+            lengths = paddle.ones([len(waveforms)])
+
+        # Augmentation
+        with paddle.no_grad():
+            if hasattr(self, "add_reverb"):
+                try:
+                    waveforms = self.add_reverb(waveforms, lengths)
+                except Exception:
+                    pass
+            if hasattr(self, "add_babble"):
+                waveforms = self.add_babble(waveforms, lengths)
+            if hasattr(self, "add_noise"):
+                waveforms = self.add_noise(waveforms, lengths)
+
+        return waveforms
+
+
+def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
+    """build the augment pipeline
+    Note: this pipeline cannot be used in the paddle.DataLoader
+
+    Returns:
+        List[paddle.nn.Layer]: all the augment layers
+    """
+    logger.info("start to build the augment pipeline")
+    noise_dataset = OpenRIRNoise('noise', target_dir=target_dir)
+    rir_dataset = OpenRIRNoise('rir', target_dir=target_dir)
+
+    wavedrop = TimeDomainSpecAugment(
+        sample_rate=16000,
+        speeds=[100], )
+    speed_perturb = TimeDomainSpecAugment(
+        sample_rate=16000,
+        speeds=[95, 100, 105], )
+    add_noise = EnvCorrupt(
+        noise_dataset=noise_dataset,
+        reverb_prob=0.0,
+        noise_prob=1.0,
+        noise_snr_low=0,
+        noise_snr_high=15,
+        rir_scale_factor=1.0, )
+    add_rev = EnvCorrupt(
+        rir_dataset=rir_dataset,
+        reverb_prob=1.0,
+        noise_prob=0.0,
+        rir_scale_factor=1.0, )
+    add_rev_noise = EnvCorrupt(
+        noise_dataset=noise_dataset,
+        rir_dataset=rir_dataset,
+        reverb_prob=1.0,
+        noise_prob=1.0,
+        noise_snr_low=0,
+        noise_snr_high=15,
+        rir_scale_factor=1.0, )
+
+    return [wavedrop, speed_perturb, add_noise, add_rev, add_rev_noise]
+
+
+def waveform_augment(waveforms: paddle.Tensor,
+                     augment_pipeline: List[paddle.nn.Layer]) -> paddle.Tensor:
+    """process the augment pipeline and return all the waveforms
+
+    Args:
+        waveforms (paddle.Tensor): the batch of original waveforms, with shape (N, L)
+        augment_pipeline (List[paddle.nn.Layer]): the augment layers to apply
+
+    Returns:
+        paddle.Tensor: the original and augmented waveforms concatenated along the batch axis
+ """ + waveforms_aug_list = [waveforms] + for aug in augment_pipeline: + waveforms_aug = aug(waveforms) # (N, L) + if waveforms_aug.shape[1] >= waveforms.shape[1]: + # Trunc + waveforms_aug = waveforms_aug[:, :waveforms.shape[1]] + else: + # Pad + lengths_to_pad = waveforms.shape[1] - waveforms_aug.shape[1] + waveforms_aug = F.pad( + waveforms_aug.unsqueeze(-1), [0, lengths_to_pad], + data_format='NLC').squeeze(-1) + waveforms_aug_list.append(waveforms_aug) + + return paddle.concat(waveforms_aug_list, axis=0) diff --git a/paddlespeech/vector/io/signal_processing.py b/paddlespeech/vector/io/signal_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..a61bf554960e4a71c77e16d9021b125a26f28551 --- /dev/null +++ b/paddlespeech/vector/io/signal_processing.py @@ -0,0 +1,219 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import paddle + +# TODO: Complete type-hint and doc string. + + +def blackman_window(win_len, dtype=np.float32): + arcs = np.pi * np.arange(win_len) / float(win_len) + win = np.asarray( + [0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs], + dtype=dtype) + return paddle.to_tensor(win) + + +def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"): + if len(waveforms.shape) == 1: + waveforms = waveforms.unsqueeze(0) + + assert amp_type in ["avg", "peak"] + assert scale in ["linear", "dB"] + + if amp_type == "avg": + if lengths is None: + out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True) + else: + wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True) + out = wav_sum / lengths + elif amp_type == "peak": + out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True) + else: + raise NotImplementedError + + if scale == "linear": + return out + elif scale == "dB": + return paddle.clip(20 * paddle.log10(out), min=-80) + else: + raise NotImplementedError + + +def dB_to_amplitude(SNR): + return 10**(SNR / 20) + + +def convolve1d( + waveform, + kernel, + padding=0, + pad_type="constant", + stride=1, + groups=1, ): + if len(waveform.shape) != 3: + raise ValueError("Convolve1D expects a 3-dimensional tensor") + + # Padding can be a tuple (left_pad, right_pad) or an int + if isinstance(padding, list): + waveform = paddle.nn.functional.pad( + x=waveform, + pad=padding, + mode=pad_type, + data_format='NLC', ) + + # Move time dimension last, which pad and fft and conv expect. + # (N, L, C) -> (N, C, L) + waveform = waveform.transpose([0, 2, 1]) + kernel = kernel.transpose([0, 2, 1]) + + convolved = paddle.nn.functional.conv1d( + x=waveform, + weight=kernel, + stride=stride, + groups=groups, + padding=padding if not isinstance(padding, list) else 0, ) + + # Return time dimension to the second dimension. 
+    return convolved.transpose([0, 2, 1])
+
+
+def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
+    # Check inputs
+    assert 0 < notch_freq <= 1
+    assert filter_width % 2 != 0
+    pad = filter_width // 2
+    inputs = paddle.arange(filter_width, dtype='float32') - pad
+
+    # Avoid frequencies that are too low
+    notch_freq += notch_width
+
+    # Define sinc function, avoiding division by zero
+    def sinc(x):
+        def _sinc(x):
+            return paddle.sin(x) / x
+
+        # The zero is at the middle index
+        res = paddle.concat(
+            [_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
+        return res
+
+    # Compute a low-pass filter with cutoff frequency notch_freq.
+    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
+    hlpf *= blackman_window(filter_width)
+    hlpf /= paddle.sum(hlpf)
+
+    # Compute a high-pass filter with cutoff frequency notch_freq.
+    hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
+    hhpf *= blackman_window(filter_width)
+    hhpf /= -paddle.sum(hhpf)
+    hhpf[pad] += 1
+
+    # Adding filters creates notch filter
+    return (hlpf + hhpf).reshape([1, -1, 1])
+
+
+def reverberate(waveforms,
+                rir_waveform,
+                sample_rate,
+                impulse_duration=0.3,
+                rescale_amp="avg"):
+    orig_shape = waveforms.shape
+
+    if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
+        raise NotImplementedError
+
+    # if inputs are mono tensors we reshape to 1, samples
+    if len(waveforms.shape) == 1:
+        waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
+    elif len(waveforms.shape) == 2:
+        waveforms = waveforms.unsqueeze(-1)
+
+    if len(rir_waveform.shape) == 1:  # convolve1d expects a 3d tensor !
+        rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
+    elif len(rir_waveform.shape) == 2:
+        rir_waveform = rir_waveform.unsqueeze(-1)
+
+    # Compute the average amplitude of the clean
+    orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
+                                       rescale_amp)
+
+    # Compute index of the direct signal, so we can preserve alignment
+    impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
+    impulse_index_end = min(
+        impulse_index_start + int(sample_rate * impulse_duration),
+        rir_waveform.shape[1])
+    rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
+    rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
+    rir_waveform = paddle.flip(rir_waveform, [1])
+
+    waveforms = convolve1d(
+        waveform=waveforms,
+        kernel=rir_waveform,
+        padding=[rir_waveform.shape[1] - 1, 0], )
+
+    # Rescale to the peak amplitude of the clean waveform
+    waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
+                        rescale_amp)
+
+    if len(orig_shape) == 1:
+        waveforms = waveforms.squeeze(0).squeeze(-1)
+    if len(orig_shape) == 2:
+        waveforms = waveforms.squeeze(-1)
+
+    return waveforms
+
+
+def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
+    assert amp_type in ["peak", "avg"]
+    assert scale in ["linear", "dB"]
+
+    batch_added = False
+    if len(waveforms.shape) == 1:
+        batch_added = True
+        waveforms = waveforms.unsqueeze(0)
+
+    waveforms = normalize(waveforms, lengths, amp_type)
+
+    if scale == "linear":
+        out = target_lvl * waveforms
+    elif scale == "dB":
+        out = dB_to_amplitude(target_lvl) * waveforms
+
+    else:
+        raise NotImplementedError("Invalid scale, choose between dB and linear")
+
+    if batch_added:
+        out = out.squeeze(0)
+
+    return out
+
+
+def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
+    assert amp_type in ["avg", "peak"]
+
+    batch_added = False
+    if len(waveforms.shape) == 1:
+        batch_added = True
+        waveforms = waveforms.unsqueeze(0)
+
+    den = compute_amplitude(waveforms, lengths, amp_type) + eps
+    if batch_added:
+        waveforms = waveforms.squeeze(0)
+    return waveforms / den
diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py
index e493b8004e2255135a070483f01ca709698fc81c..4c960e117f5d952a214631d58b4c9a023af33c26 100644
--- a/paddlespeech/vector/models/ecapa_tdnn.py
+++ b/paddlespeech/vector/models/ecapa_tdnn.py
@@ -19,6 +19,16 @@ import paddle.nn.functional as F
 
 
 def length_to_mask(length, max_len=None, dtype=None):
+    """Convert a batch of sequence lengths to a batch of binary masks
+
+    Args:
+        length (paddle.Tensor): a 1-D tensor of sequence lengths
+        max_len (int, optional): the mask width; the longest length is used when None. Defaults to None.
+        dtype (paddle.dtype, optional): the data type of the returned mask. Defaults to None.
+
+    Returns:
+        paddle.Tensor: a binary mask with shape (len(length), max_len)
+    """
     assert len(length.shape) == 1
 
     if max_len is None:
@@ -47,6 +57,19 @@ class Conv1d(nn.Layer):
             groups=1,
             bias=True,
             padding_mode="reflect", ):
+        """1-D convolution layer with SAME padding support
+
+        Args:
+            in_channels (int): number of input channels
+            out_channels (int): number of output channels
+            kernel_size (int): size of the convolution kernel
+            stride (int, optional): stride of the convolution. Defaults to 1.
+            padding (str, optional): padding strategy. Defaults to "same".
+            dilation (int, optional): dilation of the convolution. Defaults to 1.
+            groups (int, optional): number of channel groups. Defaults to 1.
+            bias (bool, optional): whether to add a bias term. Defaults to True.
+            padding_mode (str, optional): the padding mode passed to F.pad. Defaults to "reflect".
+        """
         super().__init__()
 
         self.kernel_size = kernel_size
@@ -66,6 +89,17 @@ class Conv1d(nn.Layer):
             bias_attr=bias, )
 
     def forward(self, x):
+        """Apply the convolution to the input tensor
+
+        Args:
+            x (paddle.Tensor): the input tensor with shape (N, C, L)
+
+        Raises:
+            ValueError: if an unsupported padding strategy is specified
+
+        Returns:
+            paddle.Tensor: the convolved tensor
+        """
         if self.padding == "same":
             x = self._manage_padding(x, self.kernel_size, self.dilation,
                                      self.stride)
@@ -75,6 +109,17 @@ class Conv1d(nn.Layer):
         return self.conv(x)
 
     def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
+        """Pad the input so the convolution output keeps the same length
+
+        Args:
+            x (paddle.Tensor): the input tensor
+            kernel_size (int): size of the convolution kernel
+            dilation (int): dilation of the convolution
+            stride (int): stride of the convolution
+
+        Returns:
+            paddle.Tensor: the padded tensor
+        """
         L_in = x.shape[-1]  # Detecting input shape
         padding = self._get_padding_elem(L_in, stride, kernel_size, dilation)  # Time padding
@@ -88,6 +133,17 @@ class Conv1d(nn.Layer):
                           stride: int,
                           kernel_size: int,
                           dilation: int):
+        """Compute the left/right padding sizes for SAME convolution
+
+        Args:
+            L_in (int): length of the input sequence
+            stride (int): stride of the convolution
+            kernel_size (int): size of the convolution kernel
+            dilation (int): dilation of the convolution
+
+        Returns:
+            List[int]: the left and right padding sizes
+        """
         if stride > 1:
             n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
             L_out = stride * (n_steps - 1) + kernel_size * dilation
@@ -134,6 +190,15 @@ class TDNNBlock(nn.Layer):
                  kernel_size,
                  dilation,
                  activation=nn.ReLU, ):
+        """Implementation of the TDNN network block
+
+        Args:
+            in_channels (int): input channels or input embedding dimensions
+            out_channels (int): output channels or output embedding dimensions
+            kernel_size (int): the kernel size of the TDNN network block
+            dilation (int): the dilation of the TDNN network block
+            activation (paddle class, optional): the activation layers. Defaults to nn.ReLU.
+ """ super().__init__() self.conv = Conv1d( in_channels=in_channels, @@ -149,6 +214,15 @@ class TDNNBlock(nn.Layer): class Res2NetBlock(nn.Layer): def __init__(self, in_channels, out_channels, scale=8, dilation=1): + """Implementation of Res2Net Block with dilation + The paper is refered as "Res2Net: A New Multi-scale Backbone Architecture", + whose url is https://arxiv.org/abs/1904.01169 + Args: + in_channels (int): input channels or input dimensions + out_channels (int): output channels or output dimensions + scale (int, optional): _description_. Defaults to 8. + dilation (int, optional): _description_. Defaults to 1. + """ super().__init__() assert in_channels % scale == 0 assert out_channels % scale == 0 @@ -179,6 +253,14 @@ class Res2NetBlock(nn.Layer): class SEBlock(nn.Layer): def __init__(self, in_channels, se_channels, out_channels): + """Implementation of SEBlock + The paper is refered as "Squeeze-and-Excitation Networks" + whose url is https://arxiv.org/abs/1709.01507 + Args: + in_channels (int): input channels or input data dimensions + se_channels (_type_): _description_ + out_channels (int): output channels or output data dimensions + """ super().__init__() self.conv1 = Conv1d( @@ -275,6 +357,17 @@ class SERes2NetBlock(nn.Layer): kernel_size=1, dilation=1, activation=nn.ReLU, ): + """Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model + + Args: + in_channels (int): input channels or input data dimensions + out_channels (_type_): _description_ + res2net_scale (int, optional): _description_. Defaults to 8. + se_channels (int, optional): _description_. Defaults to 128. + kernel_size (int, optional): _description_. Defaults to 1. + dilation (int, optional): _description_. Defaults to 1. + activation (_type_, optional): _description_. Defaults to nn.ReLU. + """ super().__init__() self.out_channels = out_channels self.tdnn1 = TDNNBlock( diff --git a/paddlespeech/vector/training/seeding.py b/paddlespeech/vector/training/seeding.py new file mode 100644 index 0000000000000000000000000000000000000000..0778a27d61943ad63095a72a045b8ea52d8602d6 --- /dev/null +++ b/paddlespeech/vector/training/seeding.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() +import random + +import numpy as np +import paddle + + +def seed_everything(seed: int): + """Seed paddle, random and np.random to help reproductivity.""" + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + logger.info(f"Set the seed of paddle, random, np.random to {seed}.")