From 6af2bc3d5badbaa865db9d2e5b371ea2eecb9a0d Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 3 Mar 2022 16:49:46 +0800
Subject: [PATCH] add sid loss wrapper for voxceleb, test=doc

---
 examples/voxceleb/sv0/local/train.py | 43 +++++++++++++++++++++++
 paddleaudio/utils/download.py        | 28 ++++++++++-----
 paddlespeech/vector/layers/loss.py   | 78 ++++++++++++++++++++++++++++
 3 files changed, 143 insertions(+), 6 deletions(-)
 create mode 100644 paddlespeech/vector/layers/loss.py

diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
index 8dea5fff..1d9a78f9 100644
--- a/examples/voxceleb/sv0/local/train.py
+++ b/examples/voxceleb/sv0/local/train.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import os
 
 import paddle
@@ -19,7 +20,9 @@ from paddleaudio.datasets.voxceleb import VoxCeleb1
+from paddlespeech.vector.layers.loss import AdditiveAngularMargin
+from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
 from paddlespeech.vector.layers.lr import CyclicLRScheduler
 from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
 from paddlespeech.vector.training.sid_model import SpeakerIdetification
 
 
 def main(args):
     # stage0: set the training device, cpu or gpu
@@ -33,6 +36,7 @@ def main(args):
     # stage2: data prepare
     # note: some cmd must do in rank==0
     train_ds = VoxCeleb1('train', target_dir=args.data_dir)
+    dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)
 
     # stage3: build the dnn backbone model network
     model_conf = {
@@ -56,8 +60,42 @@ def main(args):
         learning_rate=lr_schedule,
         parameters=model.parameters())
 
     # stage6: build the loss function, we now only support LogSoftmaxWrapper
+    criterion = LogSoftmaxWrapper(
+        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
+
+    # stage7: determine the start epoch
+    # if a pre-trained checkpoint exists, training resumes from its epoch
+    start_epoch = 0
+    if args.load_checkpoint:
+        args.load_checkpoint = os.path.abspath(
+            os.path.expanduser(args.load_checkpoint))
+        try:
+            # load model checkpoint
+            state_dict = paddle.load(
+                os.path.join(args.load_checkpoint, 'model.pdparams'))
+            model.set_state_dict(state_dict)
+            # load optimizer checkpoint
+            state_dict = paddle.load(
+                os.path.join(args.load_checkpoint, 'model.pdopt'))
+            optimizer.set_state_dict(state_dict)
+            if local_rank == 0:
+                print(f'Checkpoint loaded from {args.load_checkpoint}')
+        except (FileNotFoundError, ValueError):
+            # paddle.load raises if the checkpoint files do not exist
+            if local_rank == 0:
+                print('Train from scratch.')
+        else:
+            # the start epoch is encoded in the checkpoint directory name,
+            # e.g. '.../epoch_10' resumes training from epoch 10
+            try:
+                start_epoch = int(
+                    args.load_checkpoint.rstrip('/').split('_')[-1])
+                print(f'Restore training from epoch {start_epoch}.')
+            except ValueError:
+                pass
+
 
 if __name__ == "__main__":
     # yapf: disable
     parser = argparse.ArgumentParser(__doc__)
@@ -73,6 +111,11 @@
                         type=float,
                         default=1e-8,
                         help="Learning rate used to train with warmup.")
+    parser.add_argument("--load_checkpoint",
+                        type=str,
+                        default=None,
+                        help="Directory of a checkpoint to resume training from.")
+
     args = parser.parse_args()
     # yapf: enable
 
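For reference, a minimal sketch of how the criterion built in stage6 and the
start_epoch from stage7 would be consumed once the training loop lands; this
is not part of the patch, and `epochs`, `train_loader`, `feats`, and `labels`
are assumed names it does not define:

    # hypothetical training-loop skeleton; model, criterion, optimizer,
    # lr_schedule and start_epoch are the objects created in main() above
    for epoch in range(start_epoch, epochs):
        model.train()
        for feats, labels in train_loader:
            logits = model(feats)             # cosine-like scores per speaker
            loss = criterion(logits, labels)  # AAM margin + log-softmax + KL
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
        lr_schedule.step()  # assumes the cyclic LR is stepped once per epoch
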
""" assert os.path.isfile(file), "File: {} not exists.".format(file) - download._decompress(file) + if path is None: + print("decompress the data: {}".format(file)) + download._decompress(file) + else: + print("decompress the data: {} to {}".format(file, path)) + if not os.path.isdir(path): + os.makedirs(path) -def download_and_decompress(archives: List[Dict[str, str]], path: str): + tmp_file = os.path.join(path, os.path.basename(file)) + os.rename(file, tmp_file) + download._decompress(tmp_file) + os.rename(tmp_file, file) + + +def download_and_decompress(archives: List[Dict[str, str]], + path: str, + decompress: bool=True): """ Download archieves and decompress to specific path. """ @@ -41,8 +55,8 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str): for archive in archives: assert 'url' in archive and 'md5' in archive, \ 'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' - - download.get_path_from_url(archive['url'], path, archive['md5']) + download.get_path_from_url( + archive['url'], path, archive['md5'], decompress=decompress) def load_state_dict_from_url(url: str, path: str, md5: str=None): diff --git a/paddlespeech/vector/layers/loss.py b/paddlespeech/vector/layers/loss.py new file mode 100644 index 00000000..bf632b13 --- /dev/null +++ b/paddlespeech/vector/layers/loss.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/vector/layers/loss.py b/paddlespeech/vector/layers/loss.py
new file mode 100644
index 00000000..bf632b13
--- /dev/null
+++ b/paddlespeech/vector/layers/loss.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class AngularMargin(nn.Layer):
+    """Plain angular margin: scale * (cosine - margin * one-hot targets)."""
+
+    def __init__(self, margin=0.0, scale=1.0):
+        super(AngularMargin, self).__init__()
+        self.margin = margin
+        self.scale = scale
+
+    def forward(self, outputs, targets):
+        # outputs are cosine similarities, targets are one-hot labels
+        outputs = outputs - self.margin * targets
+        return self.scale * outputs
+
+
+class AdditiveAngularMargin(AngularMargin):
+    """Additive angular margin: cos(theta + m) on the target class."""
+
+    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
+        super(AdditiveAngularMargin, self).__init__(margin, scale)
+        self.easy_margin = easy_margin
+
+        self.cos_m = math.cos(self.margin)
+        self.sin_m = math.sin(self.margin)
+        self.th = math.cos(math.pi - self.margin)
+        self.mm = math.sin(math.pi - self.margin) * self.margin
+
+    def forward(self, outputs, targets):
+        cosine = outputs.astype('float32')
+        sine = paddle.sqrt(1.0 - paddle.pow(cosine, 2))
+        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
+        if self.easy_margin:
+            phi = paddle.where(cosine > 0, phi, cosine)
+        else:
+            phi = paddle.where(cosine > self.th, phi, cosine - self.mm)
+        outputs = (targets * phi) + ((1.0 - targets) * cosine)
+        return self.scale * outputs
+
+
+class LogSoftmaxWrapper(nn.Layer):
+    """Margin loss followed by log-softmax and KL divergence."""
+
+    def __init__(self, loss_fn):
+        super(LogSoftmaxWrapper, self).__init__()
+        self.loss_fn = loss_fn
+        self.criterion = paddle.nn.KLDivLoss(reduction="sum")
+
+    def forward(self, outputs, targets, length=None):
+        targets = F.one_hot(targets, outputs.shape[1])
+        try:
+            predictions = self.loss_fn(outputs, targets)
+        except TypeError:
+            # fall back to margin losses that take logits only
+            predictions = self.loss_fn(outputs)
+
+        predictions = F.log_softmax(predictions, axis=1)
+        loss = self.criterion(predictions, targets) / targets.sum()
+        return loss
-- 
GitLab
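A quick self-contained check of the new loss classes; the batch size,
class count, and embedding size are arbitrary illustration values:

    import paddle
    import paddle.nn.functional as F

    from paddlespeech.vector.layers.loss import AdditiveAngularMargin
    from paddlespeech.vector.layers.loss import LogSoftmaxWrapper

    paddle.seed(0)
    batch_size, num_classes, embedding_dim = 4, 10, 192

    # AdditiveAngularMargin expects cosine similarities in [-1, 1], so build
    # them from L2-normalized embeddings and class weights
    embeddings = F.normalize(paddle.randn([batch_size, embedding_dim]), axis=1)
    weights = F.normalize(paddle.randn([num_classes, embedding_dim]), axis=1)
    cosine = paddle.matmul(embeddings, weights, transpose_y=True)

    labels = paddle.randint(0, num_classes, [batch_size])
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
    loss = criterion(cosine, labels)  # scalar tensor
    print(loss.item())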