train.py 7.0 KB
Newer Older
G
guosheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
from functools import partial
G
guosheng 已提交
21 22 23 24

import numpy as np
import paddle
import paddle.fluid as fluid
G
guosheng 已提交
25
from paddle.fluid.io import DataLoader
G
guosheng 已提交
26 27 28 29

from utils.configure import PDConfig
from utils.check import check_gpu, check_version

30
from model import Input, set_device
G
guosheng 已提交
31
from callbacks import ProgBarLogger
G
guosheng 已提交
32
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
G
guosheng 已提交
33
from transformer import Transformer, CrossEntropyCriterion
G
guosheng 已提交
34 35


G
guosheng 已提交
36
class TrainCallback(ProgBarLogger):
    """Progress-bar logger that also reports "normalized loss" and "ppl".

    After every train and eval batch both values are derived from
    ``logs["loss"][0]`` and written into ``logs``. They are appended to the
    displayed metric names when training or evaluation begins, so the base
    logger prints them alongside the regular metrics.

    Args:
        log_freq: forwarded to ``ProgBarLogger`` as its first argument.
        verbose: forwarded to ``ProgBarLogger`` as its second argument.
        loss_normalizer: constant subtracted from the raw loss to produce
            "normalized loss" (``do_train`` passes the minimum cross-entropy
            achievable under label smoothing).
    """

    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
        super(TrainCallback, self).__init__(log_freq, verbose)
        self.loss_normalizer = loss_normalizer

    def _add_derived_logs(self, logs):
        """Write "normalized loss" and "ppl" into ``logs`` from its loss."""
        loss = logs["loss"][0]
        logs["normalized loss"] = loss - self.loss_normalizer
        # Cap the exponent at 100 to limit the perplexity reported when the
        # loss diverges.
        logs["ppl"] = np.exp(min(loss, 100))

    def on_train_begin(self, logs=None):
        super(TrainCallback, self).on_train_begin(logs)
        # Build a fresh list (as on_eval_begin does) rather than extending in
        # place. This leaves the sequence object stored by the base class
        # untouched, and it also works when that sequence is a tuple.
        self.train_metrics = list(
            self.train_metrics) + ["normalized loss", "ppl"]

    def on_train_batch_end(self, step, logs=None):
        self._add_derived_logs(logs)
        super(TrainCallback, self).on_train_batch_end(step, logs)

    def on_eval_begin(self, logs=None):
        super(TrainCallback, self).on_eval_begin(logs)
        self.eval_metrics = list(
            self.eval_metrics) + ["normalized loss", "ppl"]

    def on_eval_batch_end(self, step, logs=None):
        self._add_derived_logs(logs)
        super(TrainCallback, self).on_eval_batch_end(step, logs)
G
guosheng 已提交
60 61 62


def _build_inputs(n_head):
    """Return ``(inputs, labels)``, the ``Input`` specs given to ``prepare``.

    Args:
        n_head: number of attention heads; used as the second dimension of
            the three attention-bias inputs.

    Returns:
        A pair of lists. ``inputs`` holds the seven encoder/decoder feeds
        (word ids, positions and attention biases); ``labels`` holds the
        ``label`` and ``weight`` specs.
    """
    inputs = [
        Input([None, None], "int64", name="src_word"),
        Input([None, None], "int64", name="src_pos"),
        Input([None, n_head, None, None], "float32", name="src_slf_attn_bias"),
        Input([None, None], "int64", name="trg_word"),
        Input([None, None], "int64", name="trg_pos"),
        Input([None, n_head, None, None], "float32", name="trg_slf_attn_bias"),
        Input([None, n_head, None, None], "float32", name="trg_src_attn_bias"),
    ]
    labels = [
        Input([None, 1], "int64", name="label"),
        Input([None, 1], "float32", name="weight"),
    ]
    return inputs, labels


def _build_data_loaders(args, device):
    """Build the training and (optional) validation ``DataLoader``s.

    Side effect: the vocabulary summary returned by each dataset's
    ``get_vocab_summary()`` is stored on ``args`` (``src_vocab_size``,
    ``trg_vocab_size``, ``bos_idx``, ``eos_idx``, ``unk_idx``). When a
    validation file is configured, its values are written last.

    Args:
        args: the training configuration.
        device: the place returned by ``set_device``; passed as ``places``.

    Returns:
        ``(train_loader, eval_loader)``. ``eval_loader`` is ``None`` when
        ``args.validation_file`` is falsy.
    """
    data_files = [args.training_file]
    if args.validation_file:
        data_files.append(args.validation_file)

    loaders = [None, None]
    for i, data_file in enumerate(data_files):
        dataset = Seq2SeqDataset(
            fpattern=data_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2])
        (args.src_vocab_size, args.trg_vocab_size, args.bos_idx,
         args.eos_idx, args.unk_idx) = dataset.get_vocab_summary()
        batch_sampler = Seq2SeqBatchSampler(
            dataset=dataset,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=args.shuffle,
            shuffle_batch=args.shuffle_batch,
            max_length=args.max_length)
        loaders[i] = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            places=device,
            # Source and target sequences are both padded with the EOS id.
            collate_fn=partial(
                prepare_train_input,
                src_pad_idx=args.eos_idx,
                trg_pad_idx=args.eos_idx,
                n_head=args.n_head),
            num_workers=0,  # TODO: use multi-process
            return_list=True)
    return loaders[0], loaders[1]


def _label_smoothing_loss_normalizer(label_smooth_eps, vocab_size):
    """Return the lowest cross-entropy reachable under label smoothing.

    The smoothed target puts ``1 - eps`` on the true token and
    ``eps / (vocab_size - 1)`` on each of the other tokens. The entropy of
    that distribution,
    ``-[(1 - eps) * log(1 - eps) + eps * log(eps / (vocab_size - 1))]``,
    is the loss a perfect model would still incur.
    """
    return -((1. - label_smooth_eps) * np.log((1. - label_smooth_eps)) +
             label_smooth_eps *
             # 1e-20 keeps the log finite when label_smooth_eps == 0.
             np.log(label_smooth_eps / (vocab_size - 1) + 1e-20))


def do_train(args):
    """Train the Transformer described by ``args``.

    Steps:
      1. Select the CPU/GPU device and optionally enable dygraph (eager) mode.
      2. Seed the default programs when ``args.random_seed`` is set.
      3. Build the data loaders and the model, then prepare the model with
         Adam and a noam learning-rate schedule.
      4. Optionally restore a checkpoint and/or pretrained weights.
      5. Run ``fit`` with a ``TrainCallback`` that also logs normalized loss
         and perplexity.

    Args:
        args: the built configuration (a ``PDConfig`` in this script's
            ``__main__``). Its vocabulary fields (``src_vocab_size``,
            ``trg_vocab_size``, ``bos_idx``, ``eos_idx``, ``unk_idx``) are
            overwritten from the datasets.
    """
    device = set_device("gpu" if args.use_cuda else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    # set seed for CE
    # NOTE(review): eval() turns strings such as "None" or "102" into Python
    # values, but it also executes any expression found in the config value;
    # only use trusted configuration files.
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        fluid.default_main_program().random_seed = random_seed
        fluid.default_startup_program().random_seed = random_seed

    # define inputs
    inputs, labels = _build_inputs(args.n_head)

    # define dataloaders (this also fills the vocabulary fields of args,
    # which the model construction below depends on)
    train_loader, eval_loader = _build_data_loaders(args, device)

    # define model
    transformer = Transformer(
        args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
        args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
        args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
        args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
        args.weight_sharing, args.bos_idx, args.eos_idx)

    transformer.prepare(
        fluid.optimizer.Adam(
            learning_rate=fluid.layers.noam_decay(
                args.d_model,
                args.warmup_steps,
                learning_rate=args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters()),
        CrossEntropyCriterion(args.label_smooth_eps),
        inputs=inputs,
        labels=labels)

    # init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        transformer.load(args.init_from_checkpoint)
    # init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        transformer.load(args.init_from_pretrain_model, reset_optimizer=True)

    # the best cross-entropy value with label smoothing
    loss_normalizer = _label_smoothing_loss_normalizer(args.label_smooth_eps,
                                                       args.trg_vocab_size)

    # model train
    transformer.fit(train_data=train_loader,
                    eval_data=eval_loader,
                    epochs=args.epoch,
                    eval_freq=1,
                    save_freq=1,
                    save_dir=args.save_model,
                    verbose=2,
                    callbacks=[
                        TrainCallback(
                            log_freq=args.print_step,
                            loss_normalizer=loss_normalizer)
                    ])
G
guosheng 已提交
188 189 190 191 192 193 194 195 196 197


if __name__ == "__main__":
    # Load the YAML defaults (plus any command-line overrides handled by
    # build()) and echo the resulting configuration.
    config = PDConfig(yaml_file="./transformer.yaml")
    config.build()
    config.Print()

    # Sanity-check the runtime environment before starting to train.
    check_gpu(config.use_cuda)
    check_version()

    do_train(config)