# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler

import utils.dist_utils as dist_utils
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version

# include task-specific libs
import desc
import reader
from transformer import create_net, position_encoding_init

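# setting FLAGS_eager_delete_tensor_gb to 0 enables eager garbage collection
# of intermediate tensors, which lowers GPU memory usage during training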
if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
    os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
# num_trainers is used for multi-process gpu training
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def init_from_pretrain_model(args, exe, program):

    assert isinstance(args.init_from_pretrain_model, str)

    if not os.path.exists(args.init_from_pretrain_model):
        logging.warning("The pretrained params at %s do not exist." %
                        args.init_from_pretrain_model)
        return False

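    # only pick up parameters whose files actually exist in the pretrain
    # directory, so initialization from a partially matching model still works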
    def existed_params(var):
        if not isinstance(var, fluid.framework.Parameter):
            return False
        return os.path.exists(
            os.path.join(args.init_from_pretrain_model, var.name))

    fluid.io.load_vars(
        exe,
        args.init_from_pretrain_model,
        main_program=program,
        predicate=existed_params)

    print("finish initing model from pretrained params from %s" %
          (args.init_from_pretrain_model))

    return True


def init_from_checkpoint(args, exe, program):

    assert isinstance(args.init_from_checkpoint, str)

    if not os.path.exists(args.init_from_checkpoint):
        logging.warning("The checkpoint path %s does not exist." %
                        args.init_from_checkpoint)
        return False

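    # persistables cover both the parameters and the optimizer state, so
    # training resumes exactly where the checkpoint left off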
    fluid.io.load_persistables(
        executor=exe,
        dirname=args.init_from_checkpoint,
        main_program=program,
        filename="checkpoint.pdckpt")

    print("finish initing model from checkpoint from %s" %
          (args.init_from_checkpoint))

    return True


def save_checkpoint(args, exe, program, dirname):

    assert isinstance(args.save_model_path, str)

    checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    fluid.io.save_persistables(
        exe,
        os.path.join(checkpoint_dir, dirname),
        main_program=program,
        filename="checkpoint.pdparams")

    print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname)))

    return True


def save_param(args, exe, program, dirname):

    assert isinstance(args.save_model_path, str)

    param_dir = os.path.join(args.save_model_path, args.save_param)

    if not os.path.exists(param_dir):
        os.makedirs(param_dir)

    fluid.io.save_params(
        exe,
        os.path.join(param_dir, dirname),
        main_program=program,
        filename="params.pdparams")
    print("save parameters at %s" % (os.path.join(param_dir, dirname)))

    return True


def do_train(args):
    if args.use_cuda:
        if num_trainers > 1:  # for multi-process gpu training
            dev_count = 1
        else:
            dev_count = fluid.core.get_cuda_device_count()
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=dev_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if num_trainers > 1:  # for multi-process gpu training
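        # shard the batch stream across trainer processes so that each GPU
        # consumes a different subset of the data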
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    train_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    # the seed may arrive from the yaml config as the string "None"
    random_seed = None if str(args.random_seed) == "None" else int(
        args.random_seed)
    if random_seed is not None:
        train_prog.random_seed = random_seed
        startup_prog.random_seed = random_seed

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():

            # define input and reader

            input_field_names = desc.encoder_data_input_fields + \
                    desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
            input_descs = desc.get_input_descs(args.args)
            input_slots = [{
                "name": name,
                "shape": input_descs[name][0],
                "dtype": input_descs[name][1]
            } for name in input_field_names]

            input_field = InputField(input_slots)
            input_field.build(build_pyreader=True)

            # define the network

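            # create_net returns the summed token-level cross entropy, its
            # per-token average (the loss actually optimized) and the number
            # of target tokens in the batch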
            sum_cost, avg_cost, token_num = create_net(
                is_training=True, model_input=input_field, args=args)

            # define the optimizer

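            # Noam learning-rate schedule from "Attention Is All You Need":
            # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5),
            # additionally scaled by args.learning_rate here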
            with fluid.default_main_program()._lr_schedule_guard():
                learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                    args.d_model, args.warmup_steps) * args.learning_rate

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps))
            optimizer.minimize(avg_cost)

    # prepare training

    ## decorate the pyreader with batch_generator
    input_field.loader.set_batch_generator(batch_generator)

    ## define the executor and program for training

    exe = fluid.Executor(place)

    exe.run(startup_prog)
    # init position_encoding
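    # the sinusoidal encodings are fixed rather than learned, so overwrite the
    # randomly initialized parameters with the precomputed table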
    for pos_enc_param_name in desc.pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()

        pos_enc_param.set(
            position_encoding_init(args.max_length + 1, args.d_model), place)

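    # resuming from a checkpoint and initializing from a pretrained model are
    # mutually exclusive; at most one of the two may be given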
    assert (args.init_from_checkpoint == "") or (
        args.init_from_pretrain_model == "")

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        init_from_checkpoint(args, exe, train_prog)

    ## init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        init_from_pretrain_model(args, exe, train_prog)

    build_strategy = fluid.compiler.BuildStrategy()
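    # reuse the memory of intermediate results where safe to reduce peak usage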
    build_strategy.enable_inplace = True
    exec_strategy = fluid.ExecutionStrategy()
    if num_trainers > 1:
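        # one process per GPU: let dist_utils set up the trainer ranks and
        # communication for data-parallel training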
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # the best cross-entropy value with label smoothing
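    # this equals the entropy of the smoothed target distribution, i.e. the
    # minimum achievable cross entropy; subtracting it reports a loss that
    # starts near zero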
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
    # start training

    step_idx = 0
    total_batch_num = 0  # this is for benchmark
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        input_field.loader.start()

        batch_id = 0
        while True:
            # stop early when benchmarking
            if args.max_iter and total_batch_num == args.max_iter:
                return
            try:
                outs = exe.run(compiled_train_prog,
                               fetch_list=[sum_cost.name, token_num.name]
                               if step_idx % args.print_step == 0 else [])

                if step_idx % args.print_step == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp(min(total_avg_cost, 100))))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp(min(total_avg_cost, 100)),
                             args.print_step / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % args.save_step == 0 and step_idx != 0:

                    if args.save_checkpoint:
                        save_checkpoint(args, exe, train_prog,
                                        "step_" + str(step_idx))

                    if args.save_param:
                        save_param(args, exe, train_prog,
                                   "step_" + str(step_idx))

                batch_id += 1
                step_idx += 1
                total_batch_num += 1  # this is for benchmark

                # profiler tools for benchmark
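                # profile a five-step window in the first epoch, then exit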
                if args.is_profiler and pass_id == 0 and batch_id == args.print_step:
                    profiler.start_profiler("All")
                elif args.is_profiler and pass_id == 0 and batch_id == args.print_step + 5:
                    profiler.stop_profiler("total", args.profiler_path)
                    return

            except fluid.core.EOFException:
                input_field.loader.reset()
                break

        time_consumed = time.time() - pass_start_time

    if args.save_checkpoint:
        save_checkpoint(args, exe, train_prog, "step_final")

    if args.save_param:
        save_param(args, exe, train_prog, "step_final")

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


if __name__ == "__main__":
    # without a configured handler, the logging.info calls in do_train would
    # be silent; log to stdout at INFO level
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s %(levelname)s: %(message)s")

    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)
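
# A minimal usage sketch (the file paths below are placeholders, not assets
# shipped with this repo; all options can also be set in transformer.yaml,
# and PDConfig is expected to map them to command-line flags):
#
#   python -u train.py \
#       --src_vocab_fpath data/vocab.src \
#       --trg_vocab_fpath data/vocab.trg \
#       --training_file data/train.src-trg \
#       --batch_size 4096 \
#       --use_cuda True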