train.py 10.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

G
guosheng 已提交
15
import logging
Y
Yu Yang 已提交
16
import os
G
guosheng 已提交
17
import six
G
guosheng 已提交
18
import sys
Y
Yu Yang 已提交
19
import time
Y
ying 已提交
20

Y
Yu Yang 已提交
21
import numpy as np
22
import paddle
L
Luo Tao 已提交
23
import paddle.fluid as fluid
Y
ying 已提交
24

G
Guo Sheng 已提交
25 26 27 28
import utils.dist_utils as dist_utils
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
29 30 31 32 33

# include task-specific libs
import desc
import reader
from transformer import create_net, position_encoding_init
34

35 36 37
# Give FLAGS_eager_delete_tensor_gb a default of '0' without overriding a
# value that is already present in the environment.
os.environ.setdefault('FLAGS_eager_delete_tensor_gb', '0')

# Number of trainer processes (read from PADDLE_TRAINERS_NUM, default 1);
# a value > 1 switches the script to multi-process gpu training.
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
39 40


41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
def init_from_pretrain_model(args, exe, program):
    """Initialise ``program`` from a directory of pretrained parameters.

    Only variables that are ``fluid.framework.Parameter`` instances AND have
    a file named after them inside ``args.init_from_pretrain_model`` are
    loaded; everything else keeps its current value.

    Args:
        args: config object; ``args.init_from_pretrain_model`` must be a
            ``str`` path to the directory holding the parameter files.
        exe: executor handed to ``fluid.io.load_vars``.
        program: program passed as ``main_program`` to ``fluid.io.load_vars``.

    Returns:
        True once loading has finished.

    Raises:
        AssertionError: if ``args.init_from_pretrain_model`` is not a ``str``.
        Warning: if the path does not exist.  A missing path is reported by
            raising, never by returning ``False``.
    """
    assert isinstance(args.init_from_pretrain_model, str)

    pretrain_dir = args.init_from_pretrain_model

    if not os.path.exists(pretrain_dir):
        # Fail loudly: continuing would silently train from random weights.
        raise Warning("The pretrained params do not exist.")

    def existed_params(var):
        # Predicate for load_vars: keep only parameters backed by a file.
        if not isinstance(var, fluid.framework.Parameter):
            return False
        return os.path.exists(os.path.join(pretrain_dir, var.name))

    fluid.io.load_vars(
        exe, pretrain_dir, main_program=program, predicate=existed_params)

    print("finish initing model from pretrained params from %s" %
          (pretrain_dir))

    return True


def init_from_checkpoint(args, exe, program):
    """Restore all persistable variables of ``program`` from a checkpoint.

    The checkpoint is a single combined file inside
    ``args.init_from_checkpoint``.  The file name "checkpoint.pdckpt" is
    preferred; if it is absent but "checkpoint.pdparams" (the name that
    ``save_checkpoint`` in this file writes) is present, that file is used
    instead so that checkpoints produced by this script can be resumed.

    Args:
        args: config object; ``args.init_from_checkpoint`` must be a ``str``
            path to the checkpoint directory.
        exe: executor handed to ``fluid.io.load_persistables``.
        program: program passed as ``main_program``.

    Returns:
        True once loading has finished.

    Raises:
        AssertionError: if ``args.init_from_checkpoint`` is not a ``str``.
        Warning: if the checkpoint path does not exist.  A missing path is
            reported by raising, never by returning ``False``.
    """
    assert isinstance(args.init_from_checkpoint, str)

    checkpoint_dir = args.init_from_checkpoint

    if not os.path.exists(checkpoint_dir):
        raise Warning("the checkpoint path does not exist.")

    # Pick the combined file that actually exists.  When neither is present
    # the original name is kept so the loader reports the error itself.
    filename = "checkpoint.pdckpt"
    if not os.path.exists(os.path.join(checkpoint_dir, filename)):
        fallback = "checkpoint.pdparams"
        if os.path.exists(os.path.join(checkpoint_dir, fallback)):
            filename = fallback

    fluid.io.load_persistables(
        executor=exe,
        dirname=checkpoint_dir,
        main_program=program,
        filename=filename)

    print("finish initing model from checkpoint from %s" % (checkpoint_dir))

    return True


def save_checkpoint(args, exe, program, dirname):
    """Save all persistable variables of ``program`` as a checkpoint.

    The checkpoint is written as one combined file named
    "checkpoint.pdparams" under
    ``<args.save_model_path>/<args.save_checkpoint>/<dirname>``.

    Args:
        args: config object providing ``save_model_path`` (must be a ``str``)
            and ``save_checkpoint`` (sub-directory name for checkpoints).
        exe: executor handed to ``fluid.io.save_persistables``.
        program: program passed as ``main_program``.
        dirname: name of the per-step sub-directory, e.g. "step_100".

    Returns:
        True once saving has finished.

    Raises:
        AssertionError: if ``args.save_model_path`` is not a ``str``.
        OSError: if the checkpoint directory cannot be created.
    """
    assert isinstance(args.save_model_path, str)

    checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint)

    if not os.path.isdir(checkpoint_dir):
        try:
            # makedirs (not mkdir) so a missing save_model_path is created too.
            os.makedirs(checkpoint_dir)
        except OSError:
            # Another trainer process may have created the directory between
            # the check and the call; only a genuine failure is re-raised.
            if not os.path.isdir(checkpoint_dir):
                raise

    target_dir = os.path.join(checkpoint_dir, dirname)

    fluid.io.save_persistables(
        exe, target_dir, main_program=program, filename="checkpoint.pdparams")

    print("save checkpoint at %s" % (target_dir))

    return True


def save_param(args, exe, program, dirname):
    """Save the parameters of ``program``.

    The parameters are written as one combined file named "params.pdparams"
    under ``<args.save_model_path>/<args.save_param>/<dirname>``.

    Args:
        args: config object providing ``save_model_path`` (must be a ``str``)
            and ``save_param`` (sub-directory name for parameter dumps).
        exe: executor handed to ``fluid.io.save_params``.
        program: program passed as ``main_program``.
        dirname: name of the per-step sub-directory, e.g. "step_100".

    Returns:
        True once saving has finished.

    Raises:
        AssertionError: if ``args.save_model_path`` is not a ``str``.
        OSError: if the parameter directory cannot be created.
    """
    assert isinstance(args.save_model_path, str)

    param_dir = os.path.join(args.save_model_path, args.save_param)

    if not os.path.isdir(param_dir):
        try:
            # makedirs (not mkdir) so a missing save_model_path is created too.
            os.makedirs(param_dir)
        except OSError:
            # Another trainer process may have created the directory between
            # the check and the call; only a genuine failure is re-raised.
            if not os.path.isdir(param_dir):
                raise

    target_dir = os.path.join(param_dir, dirname)

    fluid.io.save_params(
        exe, target_dir, main_program=program, filename="params.pdparams")
    print("save parameters at %s" % (target_dir))

    return True


def do_train(args):
    """Build the training program and run the training loop.

    Steps, in order: choose the device place and device count, create the
    batch generator, build the network and Adam optimizer (noam-decay
    learning rate) inside the default main/startup programs, run the startup
    program, fill the position-encoding tables, optionally restore from a
    checkpoint or pretrained parameters, then train for ``args.epoch``
    passes while logging every ``args.print_step`` steps and saving every
    ``args.save_step`` steps.  A final save is made after the last pass.

    Args:
        args: config object.  Besides the attributes read here, this function
            WRITES ``src_vocab_size``, ``trg_vocab_size``, ``bos_idx``,
            ``eos_idx`` and ``unk_idx`` onto it from the data processor's
            vocabulary summary.

    Raises:
        AssertionError: if both ``args.init_from_checkpoint`` and
            ``args.init_from_pretrain_model`` are non-empty.
    """
    if args.use_cuda:
        if num_trainers > 1:  # for multi-process gpu training
            dev_count = 1
        else:
            dev_count = fluid.core.get_cuda_device_count()
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=dev_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if num_trainers > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    train_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    # NOTE(review): eval() executes whatever the config value contains.  It is
    # kept for backward compatibility (presumably so a textual "None" becomes
    # None -- confirm); do not feed this from untrusted input.
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        train_prog.random_seed = random_seed
        startup_prog.random_seed = random_seed

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():

            # define input and reader

            # the last decoder input field is left out here
            input_field_names = (desc.encoder_data_input_fields +
                                 desc.decoder_data_input_fields[:-1] +
                                 desc.label_data_input_fields)
            input_slots = [{
                "name": name,
                "shape": desc.input_descs[name][0],
                "dtype": desc.input_descs[name][1]
            } for name in input_field_names]

            input_field = InputField(input_slots)
            input_field.build(build_pyreader=True)

            # define the network

            sum_cost, avg_cost, token_num = create_net(
                is_training=True, model_input=input_field, args=args)

            # define the optimizer

            with fluid.default_main_program()._lr_schedule_guard():
                learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                    args.d_model, args.warmup_steps) * args.learning_rate

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps))
            optimizer.minimize(avg_cost)

    # prepare training

    ## decorate the pyreader with batch_generator
    input_field.loader.set_batch_generator(batch_generator)

    ## define the executor and program for training

    exe = fluid.Executor(place)

    exe.run(startup_prog)
    # init position_encoding
    for pos_enc_param_name in desc.pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()

        pos_enc_param.set(
            position_encoding_init(args.max_length + 1, args.d_model), place)

    # at most one of the two initialisation sources may be given
    assert (args.init_from_checkpoint == "") or (
        args.init_from_pretrain_model == "")

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        init_from_checkpoint(args, exe, train_prog)

    ## init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        init_from_pretrain_model(args, exe, train_prog)

    build_strategy = fluid.compiler.BuildStrategy()
    build_strategy.enable_inplace = True
    exec_strategy = fluid.ExecutionStrategy()
    if num_trainers > 1:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

    # Defaults for the CE report below, so that it cannot hit an unbound name
    # when no epoch ran (args.epoch == 0) or the reader yielded no batch.
    total_avg_cost = float("nan")
    time_consumed = 0.0

    # start training

    step_idx = 0
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        input_field.loader.start()

        batch_id = 0
        while True:
            try:
                # values are fetched only on the steps that get logged
                outs = exe.run(compiled_train_prog,
                               fetch_list=[sum_cost.name, token_num.name]
                               if step_idx % args.print_step == 0 else [])

                if step_idx % args.print_step == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num
                    # The exponent is capped at 100.  A scalar (not a
                    # 1-element array) is passed to "%f": implicit
                    # array-to-scalar conversion is deprecated in NumPy.
                    ppl = np.exp(min(total_avg_cost, 100))

                    if step_idx == 0:
                        # no speed yet: there is no earlier timestamp
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl,
                             args.print_step / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % args.save_step == 0 and step_idx != 0:

                    if args.save_checkpoint:
                        save_checkpoint(args, exe, train_prog,
                                        "step_" + str(step_idx))

                    if args.save_param:
                        save_param(args, exe, train_prog,
                                   "step_" + str(step_idx))

                batch_id += 1
                step_idx += 1

            except fluid.core.EOFException:
                # end of this pass: reset the loader before the next one
                input_field.loader.reset()
                break

        # duration of the pass that just finished (overwritten every pass)
        time_consumed = time.time() - pass_start_time

    if args.save_checkpoint:
        save_checkpoint(args, exe, train_prog, "step_final")

    if args.save_param:
        save_param(args, exe, train_prog, "step_final")

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
Q
Qiao Longfei 已提交
318 319


320
if __name__ == "__main__":
    # Build the configuration from ./transformer.yaml and print it.
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()

    # Environment checks before any training work starts.
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)