train.py 6.3 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
17
import random
Z
Zeyu Chen 已提交
18

19
import numpy as np
Z
Zeyu Chen 已提交
20 21
import paddle
import paddlenlp as ppnlp
S
Steffy-zxf 已提交
22
from paddlenlp.data import Stack, Tuple, Pad
S
Steffy-zxf 已提交
23
from paddlenlp.datasets import ChnSentiCorp
Z
Zeyu Chen 已提交
24

S
Steffy-zxf 已提交
25
from utils import load_vocab, convert_example
Z
Zeyu Chen 已提交
26 27 28

def str2bool(value):
    """Parse a boolean command-line value.

    Accepts ``True``/``False`` (the spellings the ``--use_gpu`` help text asks
    for) case-insensitively, plus the common aliases ``t``/``f``,
    ``yes``/``no``, ``y``/``n`` and ``1``/``0``. This replaces ``type=eval``,
    which executed arbitrary Python expressions taken from the command line.

    Args:
        value(obj:`str`): The raw string argparse received for the option.

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean,
            so argparse reports a normal usage error.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected (True or False), got %r" % value)


# yapf: disable
# `description=` is the documented keyword; the first positional argument of
# ArgumentParser is `prog`, so passing __doc__ positionally set the program name.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs for training.")
parser.add_argument('--use_gpu', type=str2bool, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate used to train.")
parser.add_argument("--save_dir", type=str, default='checkpoints/', help="Directory to save model checkpoint")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to the vocabulary file.")
parser.add_argument('--network', type=str, default="bilstm_attn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
args = parser.parse_args()
# yapf: enable


def set_seed(seed=1000):
    """Seed every random number generator this script relies on.

    The same ``seed`` is passed, in order, to Python's ``random`` module,
    NumPy's global generator and Paddle.

    Args:
        seed(obj:`int`, optional, defaults to 1000): The seed value.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)


def create_dataloader(dataset,
                      trans_fn=None,
                      mode='train',
                      batch_size=1,
                      use_gpu=False,
                      pad_token_id=0,
                      batchify_fn=None):
    """
    Creates dataloader.

    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
            When given, it is applied lazily through `dataset.apply(trans_fn, lazy=True)`.
        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
            Any other value (e.g. 'validation', 'test') keeps the original order.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
        use_gpu(obj:`bool`, optional, defaults to obj:`False`): Whether to use gpu to run. In 'train' mode with
            `use_gpu=True` a `paddle.io.DistributedBatchSampler` is used; otherwise a `paddle.io.BatchSampler`.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index. NOTE: currently not used by this
            function; padding is the responsibility of `batchify_fn`. Kept for backward compatibility with callers.
        batchify_fn(obj:`callable`, optional, defaults to `None`): function to generate mini-batch data by merging
            the sample list, None for only stack each fields of sample in axis
            0(same as :attr::`np.stack(..., axis=0)`). Passed to the DataLoader as `collate_fn`.

    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
    """
    if trans_fn:
        dataset = dataset.apply(trans_fn, lazy=True)

    shuffle = mode == 'train'
    if shuffle and use_gpu:
        sampler = paddle.io.DistributedBatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=True)
    else:
        sampler = paddle.io.BatchSampler(
            dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset,
        batch_sampler=sampler,
        return_list=True,
        collate_fn=batchify_fn)


if __name__ == "__main__":
    set_seed()
    paddle.set_device('gpu' if args.use_gpu else 'cpu')

    # Loads vocab.
    if not os.path.exists(args.vocab_path):
        raise RuntimeError('The vocab_path  can not be found in the path %s' %
                           args.vocab_path)
    vocab = load_vocab(args.vocab_path)
    # Resolve special token ids once so the collate function and the
    # dataloaders agree. Previously batchify_fn indexed vocab['[PAD]'] directly
    # (KeyError if absent) while the dataloaders fell back to 0.
    pad_token_id = vocab.get('[PAD]', 0)
    unk_token_id = vocab.get('[UNK]', 1)

    # Loads dataset.
    train_ds, dev_ds, test_ds = ChnSentiCorp.get_datasets(
        ['train', 'dev', 'test'])

    # Constructs the network.
    label_list = train_ds.get_labels()
    model = ppnlp.models.Senta(
        network=args.network,
        vocab_size=len(vocab),
        num_classes=len(label_list))
    model = paddle.Model(model)

    # Reads data and generates mini-batches.
    trans_fn = partial(
        convert_example,
        vocab=vocab,
        unk_token_id=unk_token_id,
        is_test=False)
    collate_fields = Tuple(
        Pad(axis=0, pad_val=pad_token_id),  # input_ids
        Stack(dtype="int64"),  # seq len
        Stack(dtype="int64"),  # label
    )

    def batchify_fn(samples):
        """Collate a list of converted samples into a list of batched fields
        (input_ids, seq len, label)."""
        return list(collate_fields(samples))

    train_loader = create_dataloader(
        train_ds,
        trans_fn=trans_fn,
        batch_size=args.batch_size,
        mode='train',
        use_gpu=args.use_gpu,
        pad_token_id=pad_token_id,
        batchify_fn=batchify_fn)
    dev_loader = create_dataloader(
        dev_ds,
        trans_fn=trans_fn,
        batch_size=args.batch_size,
        mode='validation',
        use_gpu=args.use_gpu,
        pad_token_id=pad_token_id,
        batchify_fn=batchify_fn)
    test_loader = create_dataloader(
        test_ds,
        trans_fn=trans_fn,
        batch_size=args.batch_size,
        mode='test',
        use_gpu=args.use_gpu,
        pad_token_id=pad_token_id,
        batchify_fn=batchify_fn)

    optimizer = paddle.optimizer.Adam(
        parameters=model.parameters(), learning_rate=args.lr)

    # Defines loss and metric.
    criterion = paddle.nn.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    model.prepare(optimizer, criterion, metric)

    # Loads pre-trained parameters.
    if args.init_from_ckpt:
        model.load(args.init_from_ckpt)
        print("Loaded checkpoint from %s" % args.init_from_ckpt)

    # Starts training and evaluating.
    callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3)
    model.fit(train_loader,
              dev_loader,
              epochs=args.epochs,
              save_dir=args.save_dir,
              callbacks=callback)

    # Finally tests model.
    results = model.evaluate(test_loader)
    print("Finally test acc: %.5f" % results['acc'])