提交 2d89c80e 编写于 作者: X xiongxinlei

add waveform augment pipeline, test=doc

上级 ac4967e2
......@@ -23,9 +23,13 @@ from paddle.io import DataLoader
from tqdm import tqdm
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.metrics import compute_eer
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
......@@ -67,9 +71,19 @@ def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
return {'ids': ids, 'feats': feats, 'lengths': lengths}
# feat configuration
cpu_feat_conf = {
'n_mels': 80,
'window_size': 400, #ms
'hop_length': 160, #ms
}
def main(args):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
# set the random seed, it is a must for multiprocess training
seed_everything(args.seed)
# stage1: build the dnn backbone model network
##"channels": [1024, 1024, 1024, 1024, 3072],
......@@ -95,19 +109,18 @@ def main(args):
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
print(f'Checkpoint loaded from {args.load_checkpoint}')
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# stage4: construct the enroll and test dataloader
enrol_ds = VoxCeleb1(
subset='enrol',
target_dir=args.data_dir,
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
**cpu_feat_conf)
enrol_sampler = BatchSampler(
enrol_ds, batch_size=args.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
shuffle=False) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enrol_ds,
batch_sampler=enrol_sampler,
collate_fn=lambda x: feature_normalize(
......@@ -117,14 +130,13 @@ def main(args):
test_ds = VoxCeleb1(
subset='test',
target_dir=args.data_dir,
feat_type='melspectrogram',
random_chunk=False,
n_mels=80,
window_size=400,
hop_length=160)
**cpu_feat_conf)
test_sampler = BatchSampler(
test_ds, batch_size=args.batch_size, shuffle=True)
test_ds, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_ds,
batch_sampler=test_sampler,
collate_fn=lambda x: feature_normalize(
......@@ -136,10 +148,10 @@ def main(args):
# stage7: global embedding norm to imporve the performance
if args.global_embedding_norm:
embedding_mean = None
embedding_std = None
mean_norm = args.embedding_mean_norm
std_norm = args.embedding_std_norm
global_embedding_mean = None
global_embedding_std = None
mean_norm_flag = args.embedding_mean_norm
std_norm_flag = args.embedding_std_norm
batch_count = 0
# stage8: Compute embeddings of audios in enrol and test dataset from model.
......@@ -147,7 +159,7 @@ def main(args):
# Run multi times to make embedding normalization more stable.
for i in range(2):
for dl in [enrol_loader, test_loader]:
print(
logger.info(
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
)
with paddle.no_grad():
......@@ -162,20 +174,24 @@ def main(args):
# Global embedding normalization.
if args.global_embedding_norm:
batch_count += 1
mean = embeddings.mean(axis=0) if mean_norm else 0
std = embeddings.std(axis=0) if std_norm else 1
current_mean = embeddings.mean(
axis=0) if mean_norm_flag else 0
current_std = embeddings.std(
axis=0) if std_norm_flag else 1
# Update global mean and std.
if embedding_mean is None and embedding_std is None:
embedding_mean, embedding_std = mean, std
if global_embedding_mean is None and global_embedding_std is None:
global_embedding_mean, global_embedding_std = current_mean, current_std
else:
weight = 1 / batch_count # Weight decay by batches.
embedding_mean = (1 - weight
) * embedding_mean + weight * mean
embedding_std = (1 - weight
) * embedding_std + weight * std
global_embedding_mean = (
1 - weight
) * global_embedding_mean + weight * current_mean
global_embedding_std = (
1 - weight
) * global_embedding_std + weight * current_std
# Apply global embedding normalization.
embeddings = (
embeddings - embedding_mean) / embedding_std
embeddings = (embeddings - global_embedding_mean
) / global_embedding_std
# Update embedding dict.
id2embedding.update(dict(zip(ids, embeddings)))
......@@ -198,7 +214,7 @@ def main(args):
]) # (N, emb_size)
scores = cos_sim_func(enrol_embeddings, test_embeddings)
EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
print(
logger.info(
f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
)
......@@ -210,10 +226,18 @@ if __name__ == "__main__":
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--seed",
default=0,
type=int,
help="random seed for paddle, numpy and python random package")
parser.add_argument("--data-dir",
default="./data/",
type=str,
help="data directory")
parser.add_argument("--batch-size",
type=int,
default=16,
help="Total examples' number in batch for training.")
help="Total examples' number in batch for extract the embedding.")
parser.add_argument("--num-workers",
type=int,
default=0,
......
......@@ -22,6 +22,9 @@ from paddle.io import DistributedBatchSampler
from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddleaudio.features.core import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
......@@ -29,8 +32,11 @@ from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.modules.lr import CyclicLRScheduler
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
logger = Log(__name__).getlog()
# feat configuration
cpu_feat_conf = {
'n_mels': 80,
......@@ -47,12 +53,19 @@ def main(args):
paddle.distributed.init_parallel_env()
nranks = paddle.distributed.get_world_size()
local_rank = paddle.distributed.get_rank()
# set the random seed, it is a must for multiprocess training
seed_everything(args.seed)
# stage2: data prepare
# note: some cmd must do in rank==0
# stage2: data prepare, such vox1 and vox2 data, and augment data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_ds = VoxCeleb1('train', target_dir=args.data_dir)
dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
else:
augment_pipeline = []
# stage3: build the dnn backbone model network
#"channels": [1024, 1024, 1024, 1024, 3072],
model_conf = {
......@@ -83,7 +96,7 @@ def main(args):
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch = 0
if args.load_checkpoint:
print("load the check point")
logger.info("load the check point")
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
try:
......@@ -97,14 +110,14 @@ def main(args):
os.path.join(args.load_checkpoint, 'model.pdopt'))
optimizer.set_state_dict(state_dict)
if local_rank == 0:
print(f'Checkpoint loaded from {args.load_checkpoint}')
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
except FileExistsError:
if local_rank == 0:
print('Train from scratch.')
logger.info('Train from scratch.')
try:
start_epoch = int(args.load_checkpoint[-1])
print(f'Restore training from epoch {start_epoch}.')
logger.info(f'Restore training from epoch {start_epoch}.')
except ValueError:
pass
......@@ -137,7 +150,10 @@ def main(args):
waveforms, labels = batch['waveforms'], batch['labels']
# stage 9-2: audio sample augment method, which is done on the audio sample point
# todo
if len(augment_pipeline) != 0:
waveforms = waveform_augment(waveforms, augment_pipeline)
labels = paddle.concat(
[labels for i in range(len(augment_pipeline) + 1)])
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats = []
......@@ -185,7 +201,7 @@ def main(args):
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
print(print_msg)
logger.info(print_msg)
avg_loss = 0
num_corrects = 0
......@@ -217,7 +233,7 @@ def main(args):
num_samples = 0
# stage 9-13: evaluation the valid dataset batch data
print('Evaluate on validation dataset')
logger.info('Evaluate on validation dataset')
with paddle.no_grad():
for batch_idx, batch in enumerate(dev_loader):
waveforms, labels = batch['waveforms'], batch['labels']
......@@ -238,12 +254,12 @@ def main(args):
print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
print(print_msg)
logger.info(print_msg)
# stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
print('Saving model checkpoint to {}'.format(save_dir))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
......@@ -260,6 +276,10 @@ if __name__ == "__main__":
choices=['cpu', 'gpu'],
default="cpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--seed",
default=0,
type=int,
help="random seed for paddle, numpy and python random package")
parser.add_argument("--data-dir",
default="./data/",
type=str,
......@@ -295,6 +315,10 @@ if __name__ == "__main__":
type=str,
default='./checkpoint',
help="Directory to save model checkpoints.")
parser.add_argument("--augment",
action="store_true",
default=False,
help="Apply audio augments.")
args = parser.parse_args()
# yapf: enable
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import glob
import os
import random
from typing import Dict
from typing import List
from typing import Tuple
from paddle.io import Dataset
from tqdm import tqdm
from paddleaudio.backends import load as load_audio
from paddleaudio.backends import save_wav
from paddleaudio.datasets.dataset import feat_funcs
from paddleaudio.utils import DATA_HOME
from paddleaudio.utils import decompress
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.download import download_and_decompress
logger = Log(__name__).getlog()
__all__ = ['OpenRIRNoise']
class OpenRIRNoise(Dataset):
archieves = [
{
'url': 'http://www.openslr.org/resources/28/rirs_noises.zip',
'md5': 'e6f48e257286e05de56413b4779d8ffb',
},
]
sample_rate = 16000
meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
base_path = os.path.join(DATA_HOME, 'open_rir_noise')
wav_path = os.path.join(base_path, 'RIRS_NOISES')
csv_path = os.path.join(base_path, 'csv')
subsets = ['rir', 'noise']
def __init__(self,
subset: str='rir',
feat_type: str='raw',
target_dir=None,
random_chunk: bool=True,
chunk_duration: float=3.0,
seed: int=0,
**kwargs):
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
self.subset = subset
self.feat_type = feat_type
self.feat_config = kwargs
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
self.csv_path = os.path.join(target_dir, "open_rir_noise",
"csv") if target_dir else self.csv_path
self._data = self._get_data()
super(OpenRIRNoise, self).__init__()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def _get_data(self):
# Download audio files.
logger.info(f"rirs noises base path: {self.base_path}")
if not os.path.isdir(self.base_path):
download_and_decompress(
self.archieves, self.base_path, decompress=True)
else:
logger.info(
f"{self.base_path} already exists, we will not download and decompress again"
)
# Data preparation.
logger.info(f"prepare the csv to {self.csv_path}")
if not os.path.isdir(self.csv_path):
os.makedirs(self.csv_path)
self.prepare_data()
data = []
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav = line.strip().split(',')
data.append(self.meta_info(audio_id, float(duration), wav))
random.shuffle(data)
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple: `type(sample)._fields`
for field in type(sample)._fields:
record[field] = getattr(sample, field)
waveform, sr = load_audio(record['wav'])
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
split_chunks: bool) -> List[List[str]]:
waveform, sr = load_audio(wav_file)
audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0]
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks and audio_duration > self.chunk_duration: # Split into pieces of self.chunk_duration seconds.
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
audio_duration)
for idx, chunk in enumerate(uniq_chunks_list):
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
new_wav_file = os.path.join(self.base_path,
audio_id + f'_chunk_{idx+1:02}.wav')
save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
# id, duration, new_wav
ret.append([chunk, self.chunk_duration, new_wav_file])
else: # Keep whole audio.
ret.append([audio_id, audio_duration, wav_file])
return ret
def generate_csv(self,
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
logger.info(f'Generating csv: {output_file}')
header = ["id", "duration", "wav"]
infos = list(
tqdm(
map(self._get_audio_info, wav_files, [split_chunks] * len(
wav_files)),
total=len(wav_files)))
csv_lines = []
for info in infos:
csv_lines.extend(info)
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(self):
rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises",
"rir_list")
rir_files = []
with open(rir_list, 'r') as f:
for line in f.readlines():
rir_file = line.strip().split(' ')[-1]
rir_files.append(os.path.join(self.base_path, rir_file))
noise_list = os.path.join(self.wav_path, "pointsource_noises",
"noise_list")
noise_files = []
with open(noise_list, 'r') as f:
for line in f.readlines():
noise_file = line.strip().split(' ')[-1]
noise_files.append(os.path.join(self.base_path, noise_file))
self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
......@@ -29,9 +29,12 @@ from paddleaudio.datasets.dataset import feat_funcs
from paddleaudio.utils import DATA_HOME
from paddleaudio.utils import decompress
from paddleaudio.utils import download_and_decompress
from paddlespeech.s2t.utils.log import Log
from utils.utility import download
from utils.utility import unpack
logger = Log(__name__).getlog()
__all__ = ['VoxCeleb1']
......@@ -121,9 +124,9 @@ class VoxCeleb1(Dataset):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
print("wav base path: {}".format(self.wav_path))
logger.info(f"wav base path: {self.wav_path}")
if not os.path.isdir(self.wav_path):
print("start to download the voxceleb1 dataset")
logger.info(f"start to download the voxceleb1 dataset")
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
self.archieves_audio_dev,
self.base_path,
......@@ -135,7 +138,7 @@ class VoxCeleb1(Dataset):
# Download all parts and concatenate the files into one zip file.
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
print(f'Concatenating all parts to: {dev_zipfile}')
logger.info(f'Concatenating all parts to: {dev_zipfile}')
os.system(
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
)
......@@ -154,6 +157,9 @@ class VoxCeleb1(Dataset):
self.prepare_data()
data = []
logger.info(
f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
)
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav, start, stop, spk_id = line.strip(
......@@ -246,7 +252,7 @@ class VoxCeleb1(Dataset):
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
logger.info(f'Generating csv: {output_file}')
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
with Pool(64) as p:
......@@ -269,7 +275,7 @@ class VoxCeleb1(Dataset):
def prepare_data(self):
# Audio of speakers in veri_test_file should not be included in training set.
print("start to prepare the data csv file")
logger.info("start to prepare the data csv file")
enrol_files = set()
test_files = set()
# get the enroll and test audio file path
......@@ -299,7 +305,7 @@ class VoxCeleb1(Dataset):
speakers.add(spk)
audio_files.append(file)
print("start to generate the {}".format(
logger.info("start to generate the {}".format(
os.path.join(self.meta_path, 'spk_id2label.txt')))
# encode the train and dev speakers label to spk_id2label.txt
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
......
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
# TODO: Complete type-hint and doc string.
def blackman_window(win_len, dtype=np.float32):
arcs = np.pi * np.arange(win_len) / float(win_len)
win = np.asarray(
[0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs],
dtype=dtype)
return paddle.to_tensor(win)
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0)
assert amp_type in ["avg", "peak"]
assert scale in ["linear", "dB"]
if amp_type == "avg":
if lengths is None:
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
else:
raise NotImplementedError
if scale == "linear":
return out
elif scale == "dB":
return paddle.clip(20 * paddle.log10(out), min=-80)
else:
raise NotImplementedError
def dB_to_amplitude(SNR):
return 10**(SNR / 20)
def convolve1d(
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1, ):
if len(waveform.shape) != 3:
raise ValueError("Convolve1D expects a 3-dimensional tensor")
# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, list):
waveform = paddle.nn.functional.pad(
x=waveform,
pad=padding,
mode=pad_type,
data_format='NLC', )
# Move time dimension last, which pad and fft and conv expect.
# (N, L, C) -> (N, C, L)
waveform = waveform.transpose([0, 2, 1])
kernel = kernel.transpose([0, 2, 1])
convolved = paddle.nn.functional.conv1d(
x=waveform,
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, list) else 0, )
# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Check inputs
assert 0 < notch_freq <= 1
assert filter_width % 2 != 0
pad = filter_width // 2
inputs = paddle.arange(filter_width, dtype='float32') - pad
# Avoid frequencies that are too low
notch_freq += notch_width
# Define sinc function, avoiding division by zero
def sinc(x):
def _sinc(x):
return paddle.sin(x) / x
# The zero is at the middle index
res = paddle.concat(
[_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
return res
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
# import torch
# hlpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hlpf *= blackman_window(filter_width)
hlpf /= paddle.sum(hlpf)
# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
# hhpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hhpf *= blackman_window(filter_width)
hhpf /= -paddle.sum(hhpf)
hhpf[pad] += 1
# Adding filters creates notch filter
return (hlpf + hhpf).reshape([1, -1, 1])
def reverberate(waveforms,
rir_waveform,
sample_rate,
impulse_duration=0.3,
rescale_amp="avg"):
orig_shape = waveforms.shape
if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
raise NotImplementedError
# if inputs are mono tensors we reshape to 1, samples
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
elif len(waveforms.shape) == 2:
waveforms = waveforms.unsqueeze(-1)
if len(rir_waveform.shape) == 1: # convolve1d expects a 3d tensor !
rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
elif len(rir_waveform.shape) == 2:
rir_waveform = rir_waveform.unsqueeze(-1)
# Compute the average amplitude of the clean
orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
rescale_amp)
# Compute index of the direct signal, so we can preserve alignment
impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
impulse_index_end = min(
impulse_index_start + int(sample_rate * impulse_duration),
rir_waveform.shape[1])
rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
rir_waveform = paddle.flip(rir_waveform, [1])
waveforms = convolve1d(
waveform=waveforms,
kernel=rir_waveform,
padding=[rir_waveform.shape[1] - 1, 0], )
# Rescale to the peak amplitude of the clean waveform
waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
rescale_amp)
if len(orig_shape) == 1:
waveforms = waveforms.squeeze(0).squeeze(-1)
if len(orig_shape) == 2:
waveforms = waveforms.squeeze(-1)
return waveforms
def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
assert amp_type in ["peak", "avg"]
assert scale in ["linear", "dB"]
batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)
waveforms = normalize(waveforms, lengths, amp_type)
if scale == "linear":
out = target_lvl * waveforms
elif scale == "dB":
out = dB_to_amplitude(target_lvl) * waveforms
else:
raise NotImplementedError("Invalid scale, choose between dB and linear")
if batch_added:
out = out.squeeze(0)
return out
def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
assert amp_type in ["avg", "peak"]
batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)
den = compute_amplitude(waveforms, lengths, amp_type) + eps
if batch_added:
waveforms = waveforms.squeeze(0)
return waveforms / den
......@@ -19,6 +19,16 @@ import paddle.nn.functional as F
def length_to_mask(length, max_len=None, dtype=None):
"""_summary_
Args:
length (_type_): _description_
max_len (_type_, optional): _description_. Defaults to None.
dtype (_type_, optional): _description_. Defaults to None.
Returns:
_type_: _description_
"""
assert len(length.shape) == 1
if max_len is None:
......@@ -47,6 +57,19 @@ class Conv1d(nn.Layer):
groups=1,
bias=True,
padding_mode="reflect", ):
"""_summary_
Args:
in_channels (_type_): _description_
out_channels (_type_): _description_
kernel_size (_type_): _description_
stride (int, optional): _description_. Defaults to 1.
padding (str, optional): _description_. Defaults to "same".
dilation (int, optional): _description_. Defaults to 1.
groups (int, optional): _description_. Defaults to 1.
bias (bool, optional): _description_. Defaults to True.
padding_mode (str, optional): _description_. Defaults to "reflect".
"""
super().__init__()
self.kernel_size = kernel_size
......@@ -66,6 +89,17 @@ class Conv1d(nn.Layer):
bias_attr=bias, )
def forward(self, x):
"""_summary_
Args:
x (_type_): _description_
Raises:
ValueError: _description_
Returns:
_type_: _description_
"""
if self.padding == "same":
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
......@@ -75,6 +109,17 @@ class Conv1d(nn.Layer):
return self.conv(x)
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
"""_summary_
Args:
x (_type_): _description_
kernel_size (int): _description_
dilation (int): _description_
stride (int): _description_
Returns:
_type_: _description_
"""
L_in = x.shape[-1] # Detecting input shape
padding = self._get_padding_elem(L_in, stride, kernel_size,
dilation) # Time padding
......@@ -88,6 +133,17 @@ class Conv1d(nn.Layer):
stride: int,
kernel_size: int,
dilation: int):
"""_summary_
Args:
L_in (int): _description_
stride (int): _description_
kernel_size (int): _description_
dilation (int): _description_
Returns:
_type_: _description_
"""
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
......@@ -134,6 +190,15 @@ class TDNNBlock(nn.Layer):
kernel_size,
dilation,
activation=nn.ReLU, ):
"""Implementation of TDNN network
Args:
in_channels (int): input channels or input embedding dimensions
out_channels (int): output channels or output embedding dimensions
kernel_size (int): the kernel size of the TDNN network block
dilation (int): the dilation of the TDNN network block
activation (paddle class, optional): the activation layers. Defaults to nn.ReLU.
"""
super().__init__()
self.conv = Conv1d(
in_channels=in_channels,
......@@ -149,6 +214,15 @@ class TDNNBlock(nn.Layer):
class Res2NetBlock(nn.Layer):
def __init__(self, in_channels, out_channels, scale=8, dilation=1):
"""Implementation of Res2Net Block with dilation
The paper is refered as "Res2Net: A New Multi-scale Backbone Architecture",
whose url is https://arxiv.org/abs/1904.01169
Args:
in_channels (int): input channels or input dimensions
out_channels (int): output channels or output dimensions
scale (int, optional): _description_. Defaults to 8.
dilation (int, optional): _description_. Defaults to 1.
"""
super().__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
......@@ -179,6 +253,14 @@ class Res2NetBlock(nn.Layer):
class SEBlock(nn.Layer):
def __init__(self, in_channels, se_channels, out_channels):
"""Implementation of SEBlock
The paper is refered as "Squeeze-and-Excitation Networks"
whose url is https://arxiv.org/abs/1709.01507
Args:
in_channels (int): input channels or input data dimensions
se_channels (_type_): _description_
out_channels (int): output channels or output data dimensions
"""
super().__init__()
self.conv1 = Conv1d(
......@@ -275,6 +357,17 @@ class SERes2NetBlock(nn.Layer):
kernel_size=1,
dilation=1,
activation=nn.ReLU, ):
"""Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model
Args:
in_channels (int): input channels or input data dimensions
out_channels (_type_): _description_
res2net_scale (int, optional): _description_. Defaults to 8.
se_channels (int, optional): _description_. Defaults to 128.
kernel_size (int, optional): _description_. Defaults to 1.
dilation (int, optional): _description_. Defaults to 1.
activation (_type_, optional): _description_. Defaults to nn.ReLU.
"""
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
import random
import numpy as np
import paddle
def seed_everything(seed: int):
"""Seed paddle, random and np.random to help reproductivity."""
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
logger.info(f"Set the seed of paddle, random, np.random to {seed}.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册