train.py 4.4 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import ast
import math
import argparse

import numpy as np
import paddle

K
kinghuin 已提交
23
from data import LacDataset
Z
Zeyu Chen 已提交
24
from model import BiGruCrf
K
kinghuin 已提交
25
from paddlenlp.data import Pad, Tuple, Stack
Z
Zeyu Chen 已提交
26
from paddlenlp.layers.crf import LinearChainCrfLoss, ViterbiDecoder
K
kinghuin 已提交
27
from paddlenlp.metrics import ChunkEvaluator
Z
Zeyu Chen 已提交
28 29 30

# yapf: disable
# Command-line options for training. ``description`` is passed by keyword:
# the first positional parameter of ArgumentParser is ``prog``, so passing the
# module docstring positionally would turn it into the program name.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--data_dir", type=str, default=None, help="The folder where the dataset is located.")
parser.add_argument("--init_checkpoint", type=str, default=None, help="Path to init model.")
parser.add_argument("--model_save_dir", type=str, default=None, help="The model will be saved in this path.")
parser.add_argument("--epochs", type=int, default=10, help="Corpus iteration num.")
parser.add_argument("--batch_size", type=int, default=300, help="The number of sequences contained in a mini-batch.")
# NOTE(review): train() in this file never reads max_seq_len; confirm whether
# LacDataset is supposed to receive it.
parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest sequence.")
# Any non-zero value selects the GPU device; values > 1 spawn one process per card.
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs to use, 0 for CPU.")
parser.add_argument("--base_lr", type=float, default=0.001, help="The basic learning rate that affects the entire network.")
parser.add_argument("--emb_dim", type=int, default=128, help="The dimension in which a word is embedded.")
parser.add_argument("--hidden_size", type=int, default=128, help="The number of hidden nodes in the GRU layer.")
# Parsed with ast.literal_eval so "True"/"False" on the command line become
# real booleans; defaults to True (progress logging enabled).
parser.add_argument("--verbose", type=ast.literal_eval, default=True, help="Print reader and training time in details.")
# yapf: enable


def train(args):
    """Train the BiGRU-CRF lexical-analysis model, evaluating every epoch.

    Args:
        args (argparse.Namespace): parsed command-line options. Reads
            ``n_gpu``, ``data_dir``, ``batch_size``, ``emb_dim``,
            ``hidden_size``, ``base_lr``, ``init_checkpoint``, ``verbose``,
            ``epochs`` and ``model_save_dir``.
    """
    # Any non-zero GPU count selects the GPU device; 0 runs on CPU.
    paddle.set_device("gpu" if args.n_gpu else "cpu")

    # Create the train/test splits.
    # NOTE(review): args.max_seq_len is parsed but not passed here — confirm
    # whether LacDataset is expected to truncate to it.
    train_dataset = LacDataset(args.data_dir, mode='train')
    test_dataset = LacDataset(args.data_dir, mode='test')

    # Collate a list of samples field by field. Tuple is itself callable on
    # the sample list, so no lambda wrapper is needed.
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=0),  # word_ids
        Stack(),  # length
        Pad(axis=0, pad_val=0),  # label_ids
    )

    # Training batches go through DistributedBatchSampler so that, when
    # launched via paddle.distributed.spawn, each rank gets its own share;
    # batches are shuffled and the incomplete last batch is dropped.
    train_sampler = paddle.io.DistributedBatchSampler(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)
    train_loader = paddle.io.DataLoader(
        dataset=train_dataset,
        batch_sampler=train_sampler,
        return_list=True,
        collate_fn=batchify_fn)

    # Evaluation keeps every sample, in dataset order.
    test_sampler = paddle.io.BatchSampler(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)
    test_loader = paddle.io.DataLoader(
        dataset=test_dataset,
        batch_sampler=test_sampler,
        return_list=True,
        collate_fn=batchify_fn)

    # Define the model network and wrap it in the high-level Model API.
    network = BiGruCrf(args.emb_dim, args.hidden_size, train_dataset.vocab_size,
                       train_dataset.num_labels)
    model = paddle.Model(network)

    # Prepare optimizer, loss and metric evaluator.
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.base_lr, parameters=model.parameters())
    # The loss is built on the network's own CRF layer (network.crf).
    crf_loss = LinearChainCrfLoss(network.crf)
    # Materialize the label names as a list (dict insertion order is kept).
    chunk_evaluator = ChunkEvaluator(
        label_list=list(train_dataset.label_vocab.keys()), suffix=True)
    model.prepare(optimizer, crf_loss, chunk_evaluator)
    if args.init_checkpoint:
        model.load(args.init_checkpoint)

    # Start training. Extra progress logging only when --verbose is truthy.
    callbacks = [paddle.callbacks.ProgBarLogger(
        log_freq=10, verbose=3)] if args.verbose else None
    # batch_size/shuffle are intentionally not passed: Model.fit ignores them
    # when train_data/eval_data are DataLoaders; the samplers above decide both.
    model.fit(train_data=train_loader,
              eval_data=test_loader,
              epochs=args.epochs,
              eval_freq=1,
              log_freq=10,
              save_dir=args.model_save_dir,
              save_freq=1,
              callbacks=callbacks)
Z
Zeyu Chen 已提交
109 110 111


if __name__ == "__main__":
    args = parser.parse_args()
    # One device (a single GPU, or CPU when n_gpu is 0) trains in-process;
    # otherwise launch one training process per requested GPU.
    if args.n_gpu <= 1:
        train(args)
    else:
        paddle.distributed.spawn(train, args=(args, ), nprocs=args.n_gpu)