# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
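
"""Training entry for the Transformer machine-translation model built on
PaddlePaddle Fluid; run options are read from transformer.yaml."""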

import logging
import os
import six
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid

# include PALM for easier NLP coding
from palm.toolkit.input_field import InputField
from palm.toolkit.configure import PDConfig

# include task-specific libs
import desc
import dist_utils
import reader
from transformer import create_net, position_encoding_init

if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
    os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
# num_trainers is used for multi-process gpu training
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
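
# the training loop below reports progress via logging.info; configure the
# root logger so those messages are actually emitted (the default level is
# WARNING, which would silence them)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s")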


def init_from_pretrain_model(args, exe, program):

    assert isinstance(args.init_from_pretrain_model, str)

    if not os.path.exists(args.init_from_pretrain_model):
        # warn and skip initialization instead of raising, so the caller can
        # decide how to proceed (the original `raise Warning` made the
        # `return False` below unreachable)
        logging.warning("The pretrained params at %s do not exist." %
                        args.init_from_pretrain_model)
        return False

    def existed_params(var):
        if not isinstance(var, fluid.framework.Parameter):
            return False
        return os.path.exists(
            os.path.join(args.init_from_pretrain_model, var.name))

    fluid.io.load_vars(
        exe,
        args.init_from_pretrain_model,
        main_program=program,
        predicate=existed_params)

    print("finish initing model from pretrained params from %s" %
          (args.init_from_pretrain_model))

    return True


def init_from_checkpoint(args, exe, program):

    assert isinstance(args.init_from_checkpoint, str)

    if not os.path.exists(args.init_from_checkpoint):
        logging.warning("The checkpoint path %s does not exist." %
                        args.init_from_checkpoint)
        return False

    fluid.io.load_persistables(
        executor=exe,
        dirname=args.init_from_checkpoint,
        main_program=program,
        filename="checkpoint.pdckpt")

    print("finish initing model from checkpoint from %s" %
          (args.init_from_checkpoint))

    return True


def save_checkpoint(args, exe, program, dirname):

    assert isinstance(args.save_model_path, str)

    checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)  # also creates missing parent dirs

    fluid.io.save_persistables(
        exe,
        os.path.join(checkpoint_dir, dirname),
        main_program=program,
        # the filename must match what init_from_checkpoint loads
        filename="checkpoint.pdckpt")

    print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname)))

    return True


def save_param(args, exe, program, dirname):

    assert isinstance(args.save_model_path, str)

    param_dir = os.path.join(args.save_model_path, args.save_param)

    if not os.path.exists(param_dir):
        os.makedirs(param_dir)  # also creates missing parent dirs

    fluid.io.save_params(
        exe,
        os.path.join(param_dir, dirname),
        main_program=program,
        filename="params.pdparams")
    print("save parameters at %s" % (os.path.join(param_dir, dirname)))

    return True


def do_train(args):
    if args.use_cuda:
        if num_trainers > 1:  # for multi-process gpu training
            dev_count = 1
        else:
            dev_count = fluid.core.get_cuda_device_count()
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=dev_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if num_trainers > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
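    # read back the actual vocab sizes and special-token ids so the network
    # definition below stays consistent with the data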
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    train_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
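    # random_seed may arrive as the string "None" from the yaml config;
    # eval maps it to None or to an int accordingly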
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        train_prog.random_seed = random_seed
        startup_prog.random_seed = random_seed

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():

            # define input and reader

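            # take the encoder inputs, all decoder inputs except the last
            # one (the pre-computed encoder output, which is fed only at
            # inference time), and the label fields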
            input_field_names = desc.encoder_data_input_fields + \
                    desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
            input_slots = [{
                "name": name,
                "shape": desc.input_descs[name][0],
                "dtype": desc.input_descs[name][1]
            } for name in input_field_names]

            input_field = InputField(input_slots)
            input_field.build(build_pyreader=True)

            # define the network

            sum_cost, avg_cost, token_num = create_net(
                is_training=True, model_input=input_field, args=args)

            # mark as persistable so eager garbage collection keeps these
            # variables alive and they can be fetched every print_step steps
            sum_cost.persistable = avg_cost.persistable = token_num.persistable = True

            # define the optimizer
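            # Noam schedule: the rate grows linearly for warmup_steps steps,
            # then decays with the inverse square root of the step count; the
            # whole schedule is scaled by d_model**-0.5 and by the base
            # learning_rate from the config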
            with fluid.default_main_program()._lr_schedule_guard():
                learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                    args.d_model, args.warmup_steps) * args.learning_rate

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps))
            optimizer.minimize(avg_cost)

    # prepare training

    ## decorate the pyreader with batch_generator
    input_field.reader.decorate_batch_generator(batch_generator)

    ## define the executor and program for training

    exe = fluid.Executor(place)
    exe.run(startup_prog)
    # init position_encoding
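    # the sinusoidal encodings are deterministic, so they are computed on the
    # host with position_encoding_init and copied into the parameter tensors
    # after the startup program has created them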
    for pos_enc_param_name in desc.pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()

        pos_enc_param.set(
            position_encoding_init(args.max_length + 1, args.d_model), place)

    assert (args.init_from_checkpoint == "") or (
        args.init_from_pretrain_model == ""
    ), "init_from_checkpoint and init_from_pretrain_model are mutually exclusive"

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        init_from_checkpoint(args, exe, train_prog)

    ## init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        init_from_pretrain_model(args, exe, train_prog)

    build_strategy = fluid.compiler.BuildStrategy()
    exec_strategy = fluid.ExecutionStrategy()
    if num_trainers > 1:
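        # each trainer process drives exactly one GPU; dist_utils configures
        # the collective communication between processes, and a single
        # executor thread per process avoids oversubscription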
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
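    # with_data_parallel replicates the program over the visible devices (or
    # over trainer processes when num_trainers > 1) and aggregates the
    # gradients of avg_cost across the replicas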
    # the best cross-entropy value with label smoothing
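    # this equals the entropy of the smoothed label distribution, i.e. the
    # lowest cross-entropy any model could reach, so subtracting it below
    # yields the reported "normalized loss"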
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
    # start training

    step_idx = 0
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        input_field.reader.start()

        batch_id = 0
        while True:
            try:
                outs = exe.run(compiled_train_prog,
                               fetch_list=[sum_cost.name, token_num.name]
                               if step_idx % args.print_step == 0 else [])

                if step_idx % args.print_step == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             # cap the exponent at 100 so np.exp cannot overflow
                             np.exp(min(total_avg_cost, 100))))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp(min(total_avg_cost, 100)),
                             args.print_step / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % args.save_step == 0 and step_idx != 0:

                    if args.save_checkpoint:
                        save_checkpoint(args, exe, train_prog,
                                        "step_" + str(step_idx))

                    if args.save_param:
                        save_param(args, exe, train_prog,
                                   "step_" + str(step_idx))

                batch_id += 1
                step_idx += 1

            except fluid.core.EOFException:
                input_field.reader.reset()
                break

        time_consumed = time.time() - pass_start_time

    if args.save_checkpoint:
        save_checkpoint(args, exe, train_prog, "step_final")

    if args.save_param:
        save_param(args, exe, train_prog, "step_final")

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
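
# typical invocation; every option (data paths, vocab files, model size, ...)
# comes from transformer.yaml:
#   python -u train.py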


if __name__ == "__main__":
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()

    do_train(args)