# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
from paddlespeech.vector.layers.lr import CyclicLRScheduler
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.training.sid_model import SpeakerIdetification

def main(args):
    # stage0: set the training device, cpu or gpu
    paddle.set_device(args.device)

    # stage1: we must call the paddle.distributed.init_parallel_env() API at the beginning
    paddle.distributed.init_parallel_env()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()

    # stage2: data prepare
    # note: some data preparation commands must run only on rank == 0
    train_ds = VoxCeleb1('train', target_dir=args.data_dir)
    dev_ds = VoxCeleb1('dev', target_dir=args.data_dir)

    # stage3: build the dnn backbone model network
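    # the settings below follow the standard ECAPA-TDNN recipe: 80-dim fbank
    # input, 1024-channel SE-Res2Net blocks aggregated into a 3072-channel
    # feature map, 128-channel attentive statistical pooling and 192-dim
    # speaker embeddings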
    model_conf = {
        "input_size": 80,
        "channels": [1024, 1024, 1024, 1024, 3072],
        "kernel_sizes": [5, 3, 3, 3, 1],
        "dilations": [1, 2, 3, 4, 1],
        "attention_channels": 128,
        "lin_neurons": 192,
    }
    ecapa_tdnn = EcapaTdnn(**model_conf)

    # stage4: build the speaker verification training model on top of the backbone
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

    # stage5: build the optimizer; currently only AdamW is supported
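    # the cyclic schedule oscillates the learning rate between base_lr and
    # max_lr: it ramps from 1e-8 up to 1e-3 over step_size iterations
    # (140000 split across ranks) and then back down again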
    lr_schedule = CyclicLRScheduler(
        base_lr=args.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_schedule, parameters=model.parameters())

    # stage6: build the loss function; currently only LogSoftmaxWrapper is supported
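    # AdditiveAngularMargin (AAM-softmax / ArcFace) perturbs the target-class
    # logit from cos(theta) to cos(theta + m) and scales all logits by s
    # before the softmax, here with margin m=0.2 and scale s=30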
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))

    # stage7: determine the training start epoch
    #         if a pre-trained checkpoint is provided, resume from its epoch
    start_epoch = 0
    if args.load_checkpoint:
        args.load_checkpoint = os.path.abspath(
            os.path.expanduser(args.load_checkpoint))
        try:
            # load model checkpoint
            state_dict = paddle.load(
                os.path.join(args.load_checkpoint, 'model.pdparams'))
            model.set_state_dict(state_dict)

            # load optimizer checkpoint
            state_dict = paddle.load(
                os.path.join(args.load_checkpoint, 'model.pdopt'))
            optimizer.set_state_dict(state_dict)
            if local_rank == 0:
                print(f'Checkpoint loaded from {args.load_checkpoint}')
        except (FileNotFoundError, ValueError):  # checkpoint files are missing
            if local_rank == 0:
                print('Train from scratch.')

        try:
            # the checkpoint directory is expected to end with the epoch
            # number (e.g. ".../epoch_3"), so recover it from the name
            start_epoch = int(args.load_checkpoint.split('_')[-1])
            print(f'Restore training from epoch {start_epoch}.')
        except ValueError:
            pass
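
    # stage8: run the training loop.
    # NOTE: this loop is only an illustrative sketch, not the official
    # PaddleSpeech recipe. It assumes VoxCeleb1 samples can be collated into
    # (feats, labels) batches whose layout already matches EcapaTdnn's input;
    # in practice a collate_fn that pads waveforms and extracts 80-dim fbank
    # features is needed, and distributed data sharding, dev-set evaluation
    # and checkpoint saving are omitted here.
    epochs = 10  # hypothetical epoch count; a real script would expose a flag
    train_loader = paddle.io.DataLoader(
        train_ds, batch_size=32, shuffle=True, drop_last=True)

    for epoch in range(start_epoch + 1, epochs + 1):
        model.train()
        for feats, labels in train_loader:
            logits = model(feats)             # speaker classification scores
            loss = criterion(logits, labels)  # AAM-softmax loss
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
        if local_rank == 0:
            print(f'Epoch {epoch} done, last batch loss: {float(loss):.4f}')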
    
if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="cpu",
                        help="Select which device to train model, defaults to gpu.")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=1e-8,
                        help="Learning rate used to train with warmup.")
    parser.add_argument("--load_checkpoint", 
                        type=str, 
                        default=None, 
                        help="Directory to load model checkpoint to contiune trainning.")

    args = parser.parse_args()
    # yapf: enable

    main(args)