train.py 6.1 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
17
import random
Z
Zeyu Chen 已提交
18

19
import numpy as np
Z
Zeyu Chen 已提交
20 21
import paddle
import paddlenlp as ppnlp
22
from paddlenlp.data import JiebaTokenizer, Pad, Stack, Tuple, Vocab
S
Steffy-zxf 已提交
23
from paddlenlp.datasets import ChnSentiCorp
Z
Zeyu Chen 已提交
24

25
from utils import convert_example
Z
Zeyu Chen 已提交
26 27 28

# yapf: disable
def str2bool(value):
    """
    Parses a boolean command-line flag value.

    Used instead of ``type=eval`` so that text passed on the command line is
    never evaluated as a Python expression.

    Args:
        value(obj:`str`): The raw argument string, e.g. "True" or "false".

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        # argparse only passes string values through `type`, but stay tolerant.
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value (True or False), got %r" % value)


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs for training.")
parser.add_argument('--use_gpu', type=str2bool, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate used to train.")
parser.add_argument("--save_dir", type=str, default='checkpoints/', help="Directory to save model checkpoint")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to the vocabulary file.")
parser.add_argument('--network', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
args = parser.parse_args()
# yapf: enable


def set_seed(seed=1000):
    """
    Seeds every random number generator the training script relies on.

    Args:
        seed(obj:`int`, optional, defaults to 1000): The seed handed, in order,
            to Python's `random`, NumPy and Paddle.
    """
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)


def create_dataloader(dataset,
                      trans_fn=None,
                      mode='train',
                      batch_size=1,
                      use_gpu=False,
                      batchify_fn=None):
    """
    Creates dataloader.

    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
        use_gpu(obj:`bool`, optional, defaults to obj:`False`): Whether to use gpu to run. In 'train' mode this
            selects `paddle.io.DistributedBatchSampler` instead of `paddle.io.BatchSampler`.
        batchify_fn(obj:`callable`, optional, defaults to `None`): function to generate mini-batch data by merging
            the sample list, None for only stack each fields of sample in axis
            0(same as :attr::`np.stack(..., axis=0)`).

    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
    """
    if trans_fn:
        # NOTE(review): lazy=True presumably defers conversion until a sample
        # is accessed — confirm against the paddlenlp dataset API.
        dataset = dataset.apply(trans_fn, lazy=True)

    shuffle = mode == 'train'
    if shuffle and use_gpu:
        # GPU training uses the distributed sampler (always shuffled).
        sampler = paddle.io.DistributedBatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=True)
    else:
        sampler = paddle.io.BatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset,
        batch_sampler=sampler,
        return_list=True,
        collate_fn=batchify_fn)


if __name__ == "__main__":
    set_seed()
    paddle.set_device('gpu' if args.use_gpu else 'cpu')

    # Loads vocab.
    if not os.path.exists(args.vocab_path):
        raise RuntimeError('The vocab_path  can not be found in the path %s' %
                           args.vocab_path)
    vocab = Vocab.load_vocabulary(
        args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')

    # Loads dataset.
    train_ds, dev_ds, test_ds = ChnSentiCorp.get_datasets(
        ['train', 'dev', 'test'])

    # Constructs the network.
    label_list = train_ds.get_labels()
    model = ppnlp.models.Senta(
        network=args.network,
        vocab_size=len(vocab),
        num_classes=len(label_list))
    model = paddle.Model(model)

    # Reads data and generates mini-batches.
    tokenizer = JiebaTokenizer(vocab)
    trans_fn = partial(convert_example, tokenizer=tokenizer, is_test=False)
    batchify = Tuple(
        Pad(axis=0, pad_val=vocab.token_to_idx.get('[PAD]', 0)),  # input_ids
        Stack(dtype="int64"),  # seq len
        Stack(dtype="int64"),  # label
    )

    def batchify_fn(samples):
        """Merges a list of samples into [input_ids, seq_lens, labels] batches."""
        return list(batchify(samples))

    # Options shared by the train/dev/test loaders; only the mode differs.
    loader_kwargs = dict(
        trans_fn=trans_fn,
        batch_size=args.batch_size,
        use_gpu=args.use_gpu,
        batchify_fn=batchify_fn)
    train_loader = create_dataloader(train_ds, mode='train', **loader_kwargs)
    dev_loader = create_dataloader(dev_ds, mode='validation', **loader_kwargs)
    test_loader = create_dataloader(test_ds, mode='test', **loader_kwargs)

    optimizer = paddle.optimizer.Adam(
        parameters=model.parameters(), learning_rate=args.lr)

    # Defines loss and metric.
    criterion = paddle.nn.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    model.prepare(optimizer, criterion, metric)

    # Loads pre-trained parameters.
    if args.init_from_ckpt:
        model.load(args.init_from_ckpt)
        print("Loaded checkpoint from %s" % args.init_from_ckpt)

    # Starts training and evaluating.
    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
    model.fit(train_loader,
              dev_loader,
              epochs=args.epochs,
              save_dir=args.save_dir,
              callbacks=callback)

    # Finally tests model.
    results = model.evaluate(test_loader)
    print("Finally test acc: %.5f" % results['acc'])