diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py index 1d9a78f98c1c2658890733d023c9310f063ea1c2..bddb94bb2dfbaa896f82522764db6df767ea819e 100644 --- a/examples/voxceleb/sv0/local/train.py +++ b/examples/voxceleb/sv0/local/train.py @@ -11,16 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import argparse +import os import paddle +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler from paddleaudio.datasets.voxceleb import VoxCeleb1 +from paddlespeech.vector.datasets.batch import waveform_collate_fn +from paddlespeech.vector.layers.loss import AdditiveAngularMargin +from paddlespeech.vector.layers.loss import LogSoftmaxWrapper from paddlespeech.vector.layers.lr import CyclicLRScheduler from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.training.sid_model import SpeakerIdetification -from paddlespeech.vector.layers.loss import AdditiveAngularMargin, LogSoftmaxWrapper + def main(args): # stage0: set the training device, cpu or gpu @@ -61,7 +66,6 @@ def main(args): criterion = LogSoftmaxWrapper( loss_fn=AdditiveAngularMargin(margin=0.2, scale=30)) - # stage7: confirm training start epoch # if pre-trained model exists, start epoch confirmed by the pre-trained model start_epoch = 0 @@ -89,7 +93,19 @@ def main(args): print(f'Restore training from epoch {start_epoch}.') except ValueError: pass - + + # stage8: we build the batch sampler for paddle.DataLoader + train_sampler = DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False) + train_loader = DataLoader( + train_ds, + batch_sampler=train_sampler, + num_workers=args.num_workers, + collate_fn=waveform_collate_fn, + return_list=True, + use_buffer_reader=True, ) + + if __name__ == "__main__": # yapf: disable parser = argparse.ArgumentParser(__doc__) @@ -105,10 +121,17 @@ if __name__ == "__main__": type=float, default=1e-8, help="Learning rate used to train with warmup.") - parser.add_argument("--load_checkpoint", - type=str, - default=None, + parser.add_argument("--load_checkpoint", + type=str, + default=None, help="Directory to load model checkpoint to contiune trainning.") + parser.add_argument("--batch_size", + type=int, default=64, + help="Total examples' number in batch for training.") + parser.add_argument("--num_workers", + type=int, + default=0, + help="Number of workers in dataloader.") args = parser.parse_args() # yapf: enable diff --git a/paddlespeech/vector/datasets/batch.py b/paddlespeech/vector/datasets/batch.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e5d6ee325b57869311e527e23f5afd694457d6 --- /dev/null +++ b/paddlespeech/vector/datasets/batch.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def waveform_collate_fn(batch): + waveforms = np.stack([item['feat'] for item in batch]) + labels = np.stack([item['label'] for item in batch]) + + return {'waveforms': waveforms, 'labels': labels}