train.py 5.9 KB
Newer Older
G
guosheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import numpy as np
import paddle
import paddle.fluid as fluid
D
dengkaipeng 已提交
21
from paddle.io import DataLoader
22
from paddle.static import InputSpec as Input
G
guosheng 已提交
23 24 25 26

from utils.configure import PDConfig
from utils.check import check_gpu, check_version

G
guosheng 已提交
27
from reader import create_data_loader
G
guosheng 已提交
28
from transformer import Transformer, CrossEntropyCriterion
G
guosheng 已提交
29 30


31
class TrainCallback(paddle.callbacks.ProgBarLogger):
    """Progress-bar logger that also reports normalized loss and perplexity.

    Extends ``paddle.callbacks.ProgBarLogger``. Every train/eval batch log
    gains two entries derived from ``logs["loss"][0]``:

    * ``"normalized loss"``: the loss minus ``self.loss_normalizer``, the best
      cross-entropy reachable with label smoothing.
    * ``"ppl"``: ``exp(loss)``, with the loss clipped at 100 first so the
      exponential cannot overflow.

    At the first batch (``step == 0``) of a train/eval pass, it can also
    overwrite the progress bar's step count using a user-supplied callable.

    Args:
        args: Config object. Reads ``print_step``, ``label_smooth_eps`` and
            ``trg_vocab_size``.
        verbose: Verbosity level forwarded to ``ProgBarLogger``.
        train_steps_fn: Optional zero-argument callable returning the number
            of training steps. It is called at training step 0.
        eval_steps_fn: Same as ``train_steps_fn``, for evaluation.
    """

    # Extra log entries this callback derives from the batch loss.
    _EXTRA_METRICS = ("normalized loss", "ppl")

    def __init__(self,
                 args,
                 verbose=2,
                 train_steps_fn=None,
                 eval_steps_fn=None):
        # TODO(guosheng): save according to step
        super(TrainCallback, self).__init__(args.print_step, verbose)
        eps = args.label_smooth_eps
        # Best cross-entropy value achievable with label smoothing. This is the
        # entropy of a target distribution that puts (1 - eps) on the gold
        # token and eps / (trg_vocab_size - 1) on every other token.
        # The 1e-20 keeps the log finite when eps == 0.
        self.loss_normalizer = -(
            (1. - eps) * np.log(1. - eps) +
            eps * np.log(eps / (args.trg_vocab_size - 1) + 1e-20))
        self.train_steps_fn = train_steps_fn
        self.eval_steps_fn = eval_steps_fn

    def _add_loss_metrics(self, logs):
        """Write the derived "normalized loss" and "ppl" entries into ``logs``."""
        loss = logs["loss"][0]
        logs["normalized loss"] = loss - self.loss_normalizer
        # Clip before exponentiating so a diverging loss cannot overflow exp().
        logs["ppl"] = np.exp(min(loss, 100))

    def on_train_begin(self, logs=None):
        super(TrainCallback, self).on_train_begin(logs)
        # Build a new list instead of extending in place with `+=`.
        # - `+=` would mutate whatever object the parent stored. That object
        #   may be shared with other state (NOTE(review): confirm against
        #   ProgBarLogger).
        # - `+=` raises TypeError if that object is a tuple.
        # This mirrors on_eval_begin.
        self.train_metrics = list(self.train_metrics) + list(
            self._EXTRA_METRICS)

    def on_train_batch_begin(self, step, logs=None):
        if step == 0 and self.train_steps_fn:
            # Set the progress bar's total step count from the callable.
            # NOTE(review): writes ProgressBar's private `_num` attribute.
            self.train_progbar._num = self.train_steps_fn()

    def on_train_batch_end(self, step, logs=None):
        self._add_loss_metrics(logs)
        super(TrainCallback, self).on_train_batch_end(step, logs)

    def on_eval_begin(self, logs=None):
        super(TrainCallback, self).on_eval_begin(logs)
        self.eval_metrics = list(self.eval_metrics) + list(
            self._EXTRA_METRICS)

    def on_eval_batch_begin(self, step, logs=None):
        if step == 0 and self.eval_steps_fn:
            # NOTE(review): writes ProgressBar's private `_num` attribute.
            self.eval_progbar._num = self.eval_steps_fn()

    def on_eval_batch_end(self, step, logs=None):
        self._add_loss_metrics(logs)
        super(TrainCallback, self).on_eval_batch_end(step, logs)
G
guosheng 已提交
74 75 76


def do_train(args):
    """Build the Transformer, its optimizer and data loaders, then train it.

    Args:
        args: Config object (built from ``transformer.yaml`` in ``__main__``).
            It supplies the device choice, random seed, model
            hyper-parameters, optimizer settings, checkpoint paths and
            training schedule.
    """
    import ast  # local import: only used to parse the random seed below

    device = paddle.set_device("gpu" if args.use_cuda else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    # Set seed for CE. The configured value may be an int, None, or the
    # string form of either. literal_eval parses those literals; unlike the
    # previous eval(), it does not execute arbitrary expressions.
    random_seed = ast.literal_eval(str(args.random_seed))
    if random_seed is not None:
        fluid.default_main_program().random_seed = random_seed
        fluid.default_startup_program().random_seed = random_seed

    # Define the model's feed inputs. A None entry marks a dimension whose
    # size is not fixed ahead of time.
    inputs = [
        Input([None, None], "int64", name="src_word"),
        Input([None, None], "int64", name="src_pos"),
        Input([None, args.n_head, None, None],
              "float32",
              name="src_slf_attn_bias"),
        Input([None, None], "int64", name="trg_word"),
        Input([None, None], "int64", name="trg_pos"),
        Input([None, args.n_head, None, None],
              "float32",
              name="trg_slf_attn_bias"),
        Input([None, args.n_head, None, None],
              "float32",
              name="trg_src_attn_bias"),
    ]
    labels = [
        Input([None, 1], "int64", name="label"),
        Input([None, 1], "float32", name="weight"),
    ]

    # Define the data loaders. Each comes paired with a steps callable that
    # TrainCallback uses to size its progress bar.
    (train_loader, train_steps_fn), (
        eval_loader, eval_steps_fn) = create_data_loader(args, device)

    # Define the model.
    # NOTE(review): max_length + 1 presumably leaves room for an extra
    # BOS/EOS position; confirm against the Transformer implementation.
    model = paddle.Model(
        Transformer(args.src_vocab_size, args.trg_vocab_size,
                    args.max_length + 1, args.n_layer, args.n_head, args.d_key,
                    args.d_value, args.d_model, args.d_inner_hid,
                    args.prepostprocess_dropout, args.attention_dropout,
                    args.relu_dropout, args.preprocess_cmd,
                    args.postprocess_cmd, args.weight_sharing, args.bos_idx,
                    args.eos_idx), inputs, labels)

    # Adam with a Noam (warmup, then decay) learning-rate schedule, trained
    # against the label-smoothed cross-entropy criterion.
    model.prepare(
        fluid.optimizer.Adam(
            learning_rate=fluid.layers.noam_decay(
                args.d_model,
                args.warmup_steps,
                learning_rate=args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=model.parameters()),
        CrossEntropyCriterion(args.label_smooth_eps))

    # Init from a checkpoint to resume previous training.
    if args.init_from_checkpoint:
        model.load(args.init_from_checkpoint)
    # Init from pretrained weights to better solve the current task.
    # reset_optimizer=True skips restoring optimizer state. This load runs
    # after the checkpoint load above, so it is applied last if both are set.
    if args.init_from_pretrain_model:
        model.load(args.init_from_pretrain_model, reset_optimizer=True)

    # Train, evaluating and saving once per epoch.
    model.fit(train_data=train_loader,
              eval_data=eval_loader,
              epochs=args.epoch,
              eval_freq=1,
              save_freq=1,
              save_dir=args.save_model,
              callbacks=[
                  TrainCallback(
                      args,
                      train_steps_fn=train_steps_fn,
                      eval_steps_fn=eval_steps_fn)
              ])
G
guosheng 已提交
162 163 164 165 166 167 168 169 170 171


if __name__ == "__main__":
    # Assemble the run configuration from transformer.yaml, then echo it.
    # NOTE(review): build() presumably also applies command-line overrides;
    # confirm in utils.configure.
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()

    # Sanity-check the runtime environment before starting any training work.
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)