train.py 5.3 KB
Newer Older
G
guosheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import paddle
import paddle.fluid as fluid
D
dengkaipeng 已提交
24
from paddle.io import DataLoader
G
guosheng 已提交
25 26 27 28

from utils.configure import PDConfig
from utils.check import check_gpu, check_version

29
from model import Input, set_device
G
guosheng 已提交
30
from callbacks import ProgBarLogger
G
guosheng 已提交
31
from reader import create_data_loader
G
guosheng 已提交
32
from transformer import Transformer, CrossEntropyCriterion
G
guosheng 已提交
33 34


G
guosheng 已提交
35
class TrainCallback(ProgBarLogger):
    """Progress-bar logger that additionally reports normalized loss and ppl.

    For every train/eval batch it derives two extra values from
    ``logs["loss"][0]`` and adds them to ``logs`` before delegating to
    ``ProgBarLogger``:

    * ``"normalized loss"``: the loss minus ``self.loss_normalizer``, the best
      cross-entropy value attainable with label smoothing.
    * ``"ppl"``: ``exp(loss)``, with the loss clamped at 100 to keep the
      exponential from overflowing.
    """

    def __init__(self, args, verbose=2):
        """
        Args:
            args: config object; reads ``print_step`` (passed as the first
                argument of ``ProgBarLogger``, presumably the logging
                frequency -- TODO confirm), ``label_smooth_eps`` and
                ``trg_vocab_size``.
            verbose: verbosity level forwarded to ``ProgBarLogger``.
        """
        # TODO(guosheng): save according to step
        super(TrainCallback, self).__init__(args.print_step, verbose)
        # the best cross-entropy value with label smoothing; the 1e-20 keeps
        # the log argument positive when label_smooth_eps == 0
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
        self.loss_normalizer = loss_normalizer

    def _add_metrics(self, logs):
        """Add "normalized loss" and "ppl" entries to ``logs`` in place."""
        # presumably logs["loss"] is a sequence holding one loss value
        loss = logs["loss"][0]
        logs["normalized loss"] = loss - self.loss_normalizer
        # clamp before exponentiating so a diverging loss cannot overflow
        logs["ppl"] = np.exp(min(loss, 100))

    def on_train_begin(self, logs=None):
        super(TrainCallback, self).on_train_begin(logs)
        # Build a new list instead of `+=`: the base class may hand out a list
        # it (or its params) still owns, which `+=` would mutate in place.
        self.train_metrics = list(
            self.train_metrics) + ["normalized loss", "ppl"]

    def on_train_batch_end(self, step, logs=None):
        self._add_metrics(logs)
        super(TrainCallback, self).on_train_batch_end(step, logs)

    def on_eval_begin(self, logs=None):
        super(TrainCallback, self).on_eval_begin(logs)
        self.eval_metrics = list(
            self.eval_metrics) + ["normalized loss", "ppl"]

    def on_eval_batch_end(self, step, logs=None):
        self._add_metrics(logs)
        super(TrainCallback, self).on_eval_batch_end(step, logs)
G
guosheng 已提交
64 65 66


def do_train(args):
    """Build the Transformer, optimizer and data loaders, then train the model.

    Args:
        args: config object (a ``PDConfig`` when run as a script) providing
            model hyper-parameters (vocab sizes, layers, heads, dims,
            dropouts, ...), optimizer settings (``learning_rate``,
            ``warmup_steps``, ``beta1``, ``beta2``, ``eps``), checkpoint paths
            and runtime flags (``use_cuda``, ``eager_run``, ``random_seed``).
    """
    import ast

    device = set_device("gpu" if args.use_cuda else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    # set seed for CE
    # literal_eval parses values like 102, "102", None or "None" without
    # executing arbitrary text from the config (unlike eval).
    random_seed = ast.literal_eval(str(args.random_seed))
    if random_seed is not None:
        fluid.default_main_program().random_seed = random_seed
        fluid.default_startup_program().random_seed = random_seed

    # define inputs
    inputs = [
        Input(
            [None, None], "int64", name="src_word"),
        Input(
            [None, None], "int64", name="src_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="src_slf_attn_bias"),
        Input(
            [None, None], "int64", name="trg_word"),
        Input(
            [None, None], "int64", name="trg_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_slf_attn_bias"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_src_attn_bias"),
    ]
    labels = [
        Input(
            [None, 1], "int64", name="label"),
        Input(
            [None, 1], "float32", name="weight"),
    ]

    # def dataloader
    train_loader, eval_loader = create_data_loader(args, device)

    # define model
    transformer = Transformer(
        args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
        args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
        args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
        args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
        args.weight_sharing, args.bos_idx, args.eos_idx)

    transformer.prepare(
        fluid.optimizer.Adam(
            learning_rate=fluid.layers.noam_decay(
                args.d_model,
                args.warmup_steps,
                learning_rate=args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters()),
        CrossEntropyCriterion(args.label_smooth_eps),
        inputs=inputs,
        labels=labels)

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        transformer.load(args.init_from_checkpoint)
    ## init from some pretrain models, to better solve the current task
    # NOTE: runs after the checkpoint load, so if both are set the pretrain
    # weights are loaded last (and the optimizer state is reset).
    if args.init_from_pretrain_model:
        transformer.load(args.init_from_pretrain_model, reset_optimizer=True)

    # model train
    transformer.fit(train_data=train_loader,
                    eval_data=eval_loader,
                    epochs=args.epoch,
                    eval_freq=1,
                    save_freq=1,
                    save_dir=args.save_model,
                    callbacks=[TrainCallback(args)])
G
guosheng 已提交
146 147 148 149 150 151 152 153 154 155


if __name__ == "__main__":
    # Phase 1: load the YAML configuration, finalize it and echo it.
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()

    # Phase 2: sanity-check the runtime environment before training.
    check_gpu(args.use_cuda)
    check_version()

    # Phase 3: train.
    do_train(args)