# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader

from utils.configure import PDConfig
from utils.check import check_gpu, check_version

from model import Input, set_device
from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay


class LoggerCallback(ProgBarLogger):
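    """ProgBarLogger that additionally reports the loss normalized by the
    label-smoothing lower bound, plus perplexity, at each logged step."""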
    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
        super(LoggerCallback, self).__init__(log_freq, verbose)
        # TODO: wrap these overridden functions to simplify them
        self.loss_normalizer = loss_normalizer

    def on_train_begin(self, logs=None):
        super(LoggerCallback, self).on_train_begin(logs)
        self.train_metrics += ["normalized loss", "ppl"]

    def on_train_batch_end(self, step, logs=None):
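        # report the loss relative to the label-smoothing lower bound, and the
        # perplexity with the exponent clipped at 100 to avoid float overflow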
        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
        super(LoggerCallback, self).on_train_batch_end(step, logs)

    def on_eval_begin(self, logs=None):
        super(LoggerCallback, self).on_eval_begin(logs)
        self.eval_metrics += ["normalized loss", "ppl"]

    def on_eval_batch_end(self, step, logs=None):
        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
        super(LoggerCallback, self).on_eval_batch_end(step, logs)


def do_train(args):
    device = set_device("gpu" if args.use_cuda else "cpu")
    # eager_run selects imperative (dygraph) execution; otherwise the model
    # is built and run as a static graph
    if args.eager_run:
        fluid.enable_dygraph(device)

    # set seed for CE (args.random_seed may arrive as the string "None")
    random_seed = None if str(args.random_seed) == "None" else int(
        args.random_seed)
    if random_seed is not None:
        fluid.default_main_program().random_seed = random_seed
        fluid.default_startup_program().random_seed = random_seed

    # define inputs: token/position ids plus pre-computed attention bias
    # tensors of shape [batch_size, n_head, seq_len, seq_len]
    inputs = [
        Input(
            [None, None], "int64", name="src_word"),
        Input(
            [None, None], "int64", name="src_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="src_slf_attn_bias"),
        Input(
            [None, None], "int64", name="trg_word"),
        Input(
            [None, None], "int64", name="trg_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_slf_attn_bias"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_src_attn_bias"),
    ]
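    # labels: gold target-token ids and per-token weights (e.g. to mask the
    # loss on padding positions)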
    labels = [
        Input(
            [None, 1], "int64", name="label"),
        Input(
            [None, 1], "float32", name="weight"),
    ]

    # define the data loaders for training and (optional) validation
    data_loaders = [None, None]
    data_files = ([args.training_file, args.validation_file]
                  if args.validation_file else [args.training_file])
    for i, data_file in enumerate(data_files):
        dataset = Seq2SeqDataset(
            fpattern=data_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2])
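        # vocab sizes and special-token ids are read from the vocab files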
        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
            args.unk_idx = dataset.get_vocab_summary()
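        # when use_token_batch is set, batch_size counts tokens instead of
        # sentences; sort_type/pool_size control length-sorting within a
        # buffer to reduce padding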
        batch_sampler = Seq2SeqBatchSampler(
            dataset=dataset,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=args.shuffle,
            shuffle_batch=args.shuffle_batch,
            max_length=args.max_length)
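        # static-graph mode feeds the Input placeholders via feed_list;
        # dygraph mode returns each batch directly as a list of arrays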
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            places=device,
            feed_list=None if fluid.in_dygraph_mode() else
            [x.forward() for x in inputs + labels],
            collate_fn=partial(
                prepare_train_input,
                src_pad_idx=args.eos_idx,
                trg_pad_idx=args.eos_idx,
                n_head=args.n_head),
            num_workers=0,  # TODO: use multi-process
            return_list=True)
        data_loaders[i] = data_loader
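    # eval_loader stays None when no validation file is given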
    train_loader, eval_loader = data_loaders

    # define model
    transformer = Transformer(
        args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
        args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
        args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
        args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
        args.weight_sharing, args.bos_idx, args.eos_idx)

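    # Adam with the Noam schedule: the learning rate warms up linearly for
    # warmup_steps, then decays with the inverse square root of the step count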
    transformer.prepare(
        fluid.optimizer.Adam(
            learning_rate=fluid.layers.noam_decay(args.d_model,
                                                  args.warmup_steps),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters()),
        CrossEntropyCriterion(args.label_smooth_eps),
        inputs=inputs,
        labels=labels)

    # init from a checkpoint to resume the previous training
    if args.init_from_checkpoint:
        transformer.load(
            os.path.join(args.init_from_checkpoint, "transformer"))
    # init from a pretrained model to better solve the current task
    if args.init_from_pretrain_model:
        transformer.load(
            os.path.join(args.init_from_pretrain_model, "transformer"),
            reset_optimizer=True)

    # the best (minimum achievable) cross-entropy with label smoothing, i.e.
    # the entropy of the smoothed target distribution
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

    # train the model
    transformer.fit(train_data=train_loader,
                    eval_data=eval_loader,
                    epochs=1,
                    eval_freq=1,
                    save_freq=1,
                    verbose=2,
                    callbacks=[
                        LoggerCallback(
                            log_freq=args.print_step,
                            loss_normalizer=loss_normalizer)
                    ])


if __name__ == "__main__":
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)