# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

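"""Train a dygraph Transformer for machine translation on the
paddle.dataset.wmt16 data, optionally with multi-GPU data parallelism.

Example invocations (the launcher flags are an assumption; check your
PaddlePaddle version):
    single GPU: python train.py
    multi GPU:  python -m paddle.distributed.launch --selected_gpus=0,1 train.py --use_data_parallel True
"""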
from __future__ import print_function
import argparse
import ast

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
import paddle.dataset.wmt16 as wmt16

from model import TransFormer, NoamDecay
from config import *
from data_util import *


def parse_args():
    parser = argparse.ArgumentParser("Arguments for Training")
    parser.add_argument(
        "--use_data_parallel",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use multi-GPU.")
    parser.add_argument(
        "--model_file",
        type=str,
        default="transformer_params",
        help="Save the model as a file named `model_file.pdparams`.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    merge_cfg_from_list(args.opts, [TrainTaskConfig, ModelHyperParams])
    return args


def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
    """
53
    inputs for training
J
JiabinYang 已提交
54
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

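    # src_slf_attn_bias masks source padding and repeats the same row for
    # every query position; slicing with stride src_max_len keeps row 0 per
    # instance, which is then tiled over trg_max_len to form the
    # target-to-source attention bias.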
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]

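    # Convert each numpy array into a named dygraph variable; the field names
    # come from the *_data_input_fields tuples imported from config.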
    var_inputs = []
    for i, field in enumerate(encoder_data_input_fields +
                              decoder_data_input_fields[:-1] +
                              label_data_input_fields):
        var_inputs.append(to_variable(data_inputs[i], name=field))

    n_enc = len(encoder_data_input_fields)
    n_dec = len(decoder_data_input_fields[:-1])
    enc_inputs = var_inputs[:n_enc]
    dec_inputs = var_inputs[n_enc:n_enc + n_dec]
    label = var_inputs[-2]
    weights = var_inputs[-1]

    return enc_inputs, dec_inputs, label, weights


def train(args):
    """
    train models
    :return:
    """

    trainer_count = fluid.dygraph.parallel.Env().nranks
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if args.use_data_parallel else fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place):
        if args.use_data_parallel:
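            # Set up the multi-GPU communication context; the returned
            # strategy is consumed by DataParallel below.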
            strategy = fluid.dygraph.parallel.prepare_context()

        # define model
        transformer = TransFormer(
            'transformer', ModelHyperParams.src_vocab_size,
            ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
            ModelHyperParams.n_layer, ModelHyperParams.n_head,
            ModelHyperParams.d_key, ModelHyperParams.d_value,
            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
            ModelHyperParams.prepostprocess_dropout,
            ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
            ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
            ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
        # define optimizer
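        # NoamDecay (defined in model.py) is expected to implement the warmup
        # schedule from "Attention Is All You Need":
        #   lr = learning_rate * d_model**-0.5 * min(step**-0.5,
        #                                            step * warmup_steps**-1.5)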
        optimizer = fluid.optimizer.Adam(learning_rate=NoamDecay(
            ModelHyperParams.d_model, TrainTaskConfig.warmup_steps,
            TrainTaskConfig.learning_rate),
                                         beta1=TrainTaskConfig.beta1,
                                         beta2=TrainTaskConfig.beta2,
                                         epsilon=TrainTaskConfig.eps)
        if args.use_data_parallel:
            transformer = fluid.dygraph.parallel.DataParallel(
                transformer, strategy)

        # define data generator for training and validation
        train_reader = paddle.batch(wmt16.train(
            ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size),
                                    batch_size=TrainTaskConfig.batch_size)
        if args.use_data_parallel:
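            # Shard each batch across trainers so every rank trains on a
            # different slice of the data.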
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        val_reader = paddle.batch(wmt16.test(ModelHyperParams.src_vocab_size,
                                             ModelHyperParams.trg_vocab_size),
                                  batch_size=TrainTaskConfig.batch_size)

        # loop for training iterations
        for i in range(TrainTaskConfig.pass_num):
            dy_step = 0
            transformer.train()
            for batch in train_reader():
                enc_inputs, dec_inputs, label, weights = prepare_train_input(
                    batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head)

                dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                    enc_inputs, dec_inputs, label, weights)

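                # In data-parallel mode the loss is scaled down by the trainer
                # count before backward, and gradients are all-reduced across
                # trainers before the optimizer step.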
                if args.use_data_parallel:
                    dy_avg_cost = transformer.scale_loss(dy_avg_cost)
                    dy_avg_cost.backward()
                    transformer.apply_collective_grads()
                else:
                    dy_avg_cost.backward()
                optimizer.minimize(dy_avg_cost)
                transformer.clear_gradients()

                dy_step += 1
                if dy_step % 10 == 0:
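                    # scale_loss divided the loss by the trainer count, so
                    # multiply back to report the true average loss.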
                    print("pass num : {}, batch_id: {}, dy_graph avg loss: {}".
170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
                          format(i, dy_step,
                                 dy_avg_cost.numpy() * trainer_count))

            # switch to evaluation mode
            transformer.eval()
            sum_cost = 0
            token_num = 0
            for batch in val_reader():
                enc_inputs, dec_inputs, label, weights = prepare_train_input(
                    batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head)

                dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                    enc_inputs, dec_inputs, label, weights)
                sum_cost += dy_sum_cost.numpy()
                token_num += dy_token_num.numpy()
            print("pass : {} finished, validation avg loss: {}".format(
                i, sum_cost / token_num))

        if fluid.dygraph.parallel.Env().dev_id == 0:
            fluid.save_dygraph(transformer.state_dict(), args.model_file)


if __name__ == '__main__':
    args = parse_args()
    train(args)