# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid

from utils.configure import PDConfig
from utils.check import check_gpu, check_version

# include task-specific libs
import reader
from model import Transformer, CrossEntropyCriterion, NoamDecay

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def do_train(args):
    if args.use_cuda:
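        # when launched with paddle.distributed.launch, Env() reports the
        # number of trainer processes and the GPU id bound to this process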
        trainer_count = fluid.dygraph.parallel.Env().nranks
        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
        ).dev_id) if trainer_count > 1 else fluid.CUDAPlace(0)
    else:
        trainer_count = 1
        place = fluid.CPUPlace()

    # define the data generator
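    # note: with use_token_batch enabled, batch_size is interpreted as an
    # approximate token count per batch rather than a sentence count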
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=trainer_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if args.validation_file:
        val_processor = reader.DataProcessor(
            fpattern=args.validation_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            device_count=trainer_count,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=False,
            shuffle_batch=False,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=args.max_length,
            n_head=args.n_head)
        # the "train" phase format is reused here so that label fields are
        # produced for computing the validation loss
        val_batch_generator = val_processor.data_generator(phase="train")
    if trainer_count > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
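    # expose the vocab sizes and special-token ids read from the vocab files
    # to the rest of the configuration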
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    with fluid.dygraph.guard(place):
        # set random seed for reproducible continuous-evaluation (CE) runs
        random_seed = None if args.random_seed in (None, "None") else int(
            args.random_seed)
        if random_seed is not None:
            fluid.default_main_program().random_seed = random_seed
            fluid.default_startup_program().random_seed = random_seed

        # define data loader
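        # the loader pulls batches from the python generator asynchronously;
        # capacity bounds how many prepared batches are buffered ahead of use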
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        if args.validation_file:
            val_loader = fluid.io.DataLoader.from_generator(capacity=10)
            val_loader.set_batch_generator(val_batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps)

        # define optimizer
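        # NoamDecay follows the schedule of the original Transformer paper:
        # linear warmup for warmup_steps, then decay with the inverse square
        # root of the step, scaled by args.learning_rate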
        optimizer = fluid.optimizer.Adam(
            learning_rate=NoamDecay(args.d_model, args.warmup_steps,
                                    args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters())

        # initialize from a checkpoint to resume a previous training run
        if args.init_from_checkpoint:
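            # load_dygraph returns (parameter state dict, optimizer state
            # dict) saved under the given file prefix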
            model_dict, opt_dict = fluid.load_dygraph(
                os.path.join(args.init_from_checkpoint, "transformer"))
            transformer.load_dict(model_dict)
            optimizer.set_dict(opt_dict)
        # initialize from a pretrained model to better fit the current task
        if args.init_from_pretrain_model:
            model_dict, _ = fluid.load_dygraph(
                os.path.join(args.init_from_pretrain_model, "transformer"))
            transformer.load_dict(model_dict)

        if trainer_count > 1:
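            # wrap the model for data-parallel training; gradients are
            # synchronized below via scale_loss()/apply_collective_grads()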
            strategy = fluid.dygraph.parallel.prepare_context()
            transformer = fluid.dygraph.parallel.DataParallel(transformer,
                                                              strategy)

        # the best (minimum achievable) cross-entropy value with label
        # smoothing, i.e. the entropy of the smoothed target distribution
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

        ce_time = []
        ce_ppl = []
        step_idx = 0

        # train loop
        for pass_id in range(args.epoch):
            epoch_start = time.time()

            batch_id = 0
            batch_start = time.time()
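            # source tokens (padding included) consumed since the last log,
            # used to report words/s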
            interval_word_num = 0.0
            for input_data in train_loader():
                if args.max_iter and step_idx == args.max_iter:  # NOTE: early exit used for benchmarking
                    return
                batch_reader_end = time.time()

                (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
                 trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
                 lbl_weight) = input_data

                logits = transformer(src_word, src_pos, src_slf_attn_bias,
                                     trg_word, trg_pos, trg_slf_attn_bias,
                                     trg_src_attn_bias)

                sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
                                                          lbl_weight)

                if trainer_count > 1:
                    # scale the loss so that gradients averaged across
                    # trainers match the single-process case
                    avg_cost = transformer.scale_loss(avg_cost)
                    avg_cost.backward()
                    # all-reduce gradients across trainers before the update
                    transformer.apply_collective_grads()
                else:
                    avg_cost.backward()

                optimizer.minimize(avg_cost)
                transformer.clear_gradients()

                interval_word_num += np.prod(src_word.shape)
                if step_idx % args.print_step == 0:
                    total_avg_cost = avg_cost.numpy() * trainer_count
                    # record ppl for the continuous-evaluation (CE) report
                    ce_ppl.append(np.exp([min(total_avg_cost, 100)]))

                    if step_idx == 0:
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                    else:
                        train_avg_batch_cost = args.print_step / (
                            time.time() - batch_start)
                        word_speed = interval_word_num / (
                            time.time() - batch_start)
                        logger.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                            "words speed: %0.2f words/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             train_avg_batch_cost, word_speed))
                    batch_start = time.time()
                    interval_word_num = 0.0

                if step_idx % args.save_step == 0 and step_idx != 0:
                    # validation
                    if args.validation_file:
                        transformer.eval()
                        total_sum_cost = 0
                        total_token_num = 0
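                        # accumulate token-level cross entropy over the whole
                        # validation set to report a corpus-level average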
                        for input_data in val_loader():
                            (src_word, src_pos, src_slf_attn_bias, trg_word,
                             trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
                             lbl_word, lbl_weight) = input_data
                            logits = transformer(
                                src_word, src_pos, src_slf_attn_bias, trg_word,
                                trg_pos, trg_slf_attn_bias, trg_src_attn_bias)
                            sum_cost, avg_cost, token_num = criterion(
                                logits, lbl_word, lbl_weight)
                            total_sum_cost += sum_cost.numpy()
                            total_token_num += token_num.numpy()
                            total_avg_cost = total_sum_cost / total_token_num
                        logger.info("validation, step_idx: %d, avg loss: %f, "
                                    "normalized loss: %f, ppl: %f" %
                                    (step_idx, total_avg_cost,
                                     total_avg_cost - loss_normalizer,
                                     np.exp([min(total_avg_cost, 100)])))
                        transformer.train()

                    if args.save_model and (
                            trainer_count == 1 or
                            fluid.dygraph.parallel.Env().dev_id == 0):
                        model_dir = os.path.join(args.save_model,
                                                 "step_" + str(step_idx))
                        if not os.path.exists(model_dir):
                            os.makedirs(model_dir)
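                        # save_dygraph writes parameters and optimizer state
                        # to separate files (.pdparams / .pdopt) under the
                        # same "transformer" prefix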
                        fluid.save_dygraph(
                            transformer.state_dict(),
                            os.path.join(model_dir, "transformer"))
                        fluid.save_dygraph(
                            optimizer.state_dict(),
                            os.path.join(model_dir, "transformer"))

                batch_id += 1
                step_idx += 1

            train_epoch_cost = time.time() - epoch_start
            ce_time.append(train_epoch_cost)
            logger.info("train epoch: %d, epoch_cost: %.5f s" %
                        (pass_id, train_epoch_cost))

        if args.save_model:
            model_dir = os.path.join(args.save_model, "step_final")
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            fluid.save_dygraph(transformer.state_dict(),
                               os.path.join(model_dir, "transformer"))
            fluid.save_dygraph(optimizer.state_dict(),
                               os.path.join(model_dir, "transformer"))

        if args.enable_ce:
            _ppl = 0
            _time = 0
            try:
                _time = ce_time[-1]
                _ppl = ce_ppl[-1]
            except IndexError:
                logger.warning("CE info missing: ce_time or ce_ppl is empty")
            print("kpis\ttrain_duration_card%s\t%s" % (trainer_count, _time))
            print("kpis\ttrain_ppl_card%s\t%f" % (trainer_count, _ppl))


if __name__ == "__main__":
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)
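
# Typical invocations (flags and paths are illustrative; the options this
# script actually accepts come from transformer.yaml and PDConfig):
#   single GPU: python train.py --use_cuda True
#   multi GPU : python -m paddle.distributed.launch --selected_gpus 0,1,2,3 \
#               train.py --use_cuda True
#   CPU only  : python train.py --use_cuda False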