# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import six
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler

import utils.dist_utils as dist_utils
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version

# include task-specific libs
import desc
import reader
from transformer import create_net, position_encoding_init

if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
    os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
# num_trainers is used for multi-process gpu training
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def init_from_pretrain_model(args, exe, program):
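    """Initialize the parameters of `program` from the pretrained model
    directory given by args.init_from_pretrain_model; only variables that have
    a matching file in that directory are loaded.
    """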

    assert isinstance(args.init_from_pretrain_model, str)

    if not os.path.exists(args.init_from_pretrain_model):
        raise IOError("The pretrained params at %s do not exist." %
                      args.init_from_pretrain_model)

    def existed_params(var):
        if not isinstance(var, fluid.framework.Parameter):
            return False
        return os.path.exists(
            os.path.join(args.init_from_pretrain_model, var.name))

    fluid.io.load_vars(
        exe,
        args.init_from_pretrain_model,
        main_program=program,
        predicate=existed_params)

    print("finish initing model from pretrained params from %s" %
          (args.init_from_pretrain_model))

    return True


def init_from_checkpoint(args, exe, program):
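    """Restore all persistable variables (parameters and optimizer states) of
    `program` from the checkpoint directory given by args.init_from_checkpoint,
    so that a previous training run can be resumed.
    """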

    assert isinstance(args.init_from_checkpoint, str)

    if not os.path.exists(args.init_from_checkpoint):
        raise IOError("The checkpoint path %s does not exist." %
                      args.init_from_checkpoint)

    fluid.io.load_persistables(
        executor=exe,
        dirname=args.init_from_checkpoint,
        main_program=program,
        filename="checkpoint.pdckpt")

    print("finish initing model from checkpoint from %s" %
          (args.init_from_checkpoint))

    return True


def save_checkpoint(args, exe, program, dirname):
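    """Save all persistable variables of `program` (a full training checkpoint)
    under args.save_model_path/args.save_checkpoint/dirname.
    """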

    assert isinstance(args.save_model_path, str)

    checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint)

    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    fluid.io.save_persistables(
        exe,
        os.path.join(checkpoint_dir, dirname),
        main_program=program,
        filename="checkpoint.pdparams")

    print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname)))

    return True


def save_param(args, exe, program, dirname):
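    """Save only the model parameters of `program` (no optimizer state)
    under args.save_model_path/args.save_param/dirname.
    """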

    assert isinstance(args.save_model_path, str)

    param_dir = os.path.join(args.save_model_path, args.save_param)

    if not os.path.exists(param_dir):
        os.mkdir(param_dir)

    fluid.io.save_params(
        exe,
        os.path.join(param_dir, dirname),
        main_program=program,
        filename="params.pdparams")
    print("save parameters at %s" % (os.path.join(param_dir, dirname)))

    return True


def do_train(args):
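    """Train the Transformer model: set up devices and the data reader, build
    the network and the optimizer, optionally initialize from a checkpoint or
    pretrained parameters, then run the training loop with periodic logging,
    profiling and checkpointing.
    """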
    if args.use_cuda:
        if num_trainers > 1:  # for multi-process gpu training
            dev_count = 1
        else:
            dev_count = fluid.core.get_cuda_device_count()
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=dev_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if num_trainers > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    train_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    # random_seed may come from the yaml config as a string (e.g. "None")
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        train_prog.random_seed = random_seed
        startup_prog.random_seed = random_seed

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():

            # define input and reader

            input_field_names = desc.encoder_data_input_fields + \
                    desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
            input_slots = [{
                "name": name,
                "shape": desc.input_descs[name][0],
                "dtype": desc.input_descs[name][1]
            } for name in input_field_names]

            input_field = InputField(input_slots)
            input_field.build(build_pyreader=True)

            # define the network

            sum_cost, avg_cost, token_num = create_net(
                is_training=True, model_input=input_field, args=args)

            # define the optimizer
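            # Noam schedule from "Attention Is All You Need":
            # lr = args.learning_rate * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)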

            with fluid.default_main_program()._lr_schedule_guard():
                learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                    args.d_model, args.warmup_steps) * args.learning_rate

            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps))
            optimizer.minimize(avg_cost)

    # prepare training

    ## decorate the pyreader with batch_generator
    input_field.loader.set_batch_generator(batch_generator)

    ## define the executor and program for training

    exe = fluid.Executor(place)

    exe.run(startup_prog)
    # init position_encoding
    for pos_enc_param_name in desc.pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()

        pos_enc_param.set(
            position_encoding_init(args.max_length + 1, args.d_model), place)

    assert (args.init_from_checkpoint == "") or (
        args.init_from_pretrain_model == "")

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        init_from_checkpoint(args, exe, train_prog)

    ## init from a pretrained model, to better solve the current task
    if args.init_from_pretrain_model:
        init_from_pretrain_model(args, exe, train_prog)

    build_strategy = fluid.compiler.BuildStrategy()
    build_strategy.enable_inplace = True
    exec_strategy = fluid.ExecutionStrategy()
    if num_trainers > 1:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # the best cross-entropy value with label smoothing
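    # (the entropy of the smoothed target distribution, i.e. the minimum
    # achievable cross-entropy); it is subtracted from the training loss to
    # report the "normalized loss" below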
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
    # start training

    step_idx = 0
    total_batch_num = 0  # this is for benchmark
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        input_field.loader.start()

        batch_id = 0
        while True:
            if args.max_iter and total_batch_num == args.max_iter:  # this is for benchmark
                return
            try:
                outs = exe.run(compiled_train_prog,
                               fetch_list=[sum_cost.name, token_num.name]
                               if step_idx % args.print_step == 0 else [])

                if step_idx % args.print_step == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             args.print_step / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % args.save_step == 0 and step_idx != 0:

                    if args.save_checkpoint:
                        save_checkpoint(args, exe, train_prog,
                                        "step_" + str(step_idx))

                    if args.save_param:
                        save_param(args, exe, train_prog,
                                   "step_" + str(step_idx))

                batch_id += 1
                step_idx += 1
                total_batch_num += 1  # this is for benchmark

                # profiler tools for benchmark
                if args.is_profiler and pass_id == 0 and batch_id == args.print_step:
                    profiler.start_profiler("All")
                elif args.is_profiler and pass_id == 0 and batch_id == args.print_step + 5:
                    profiler.stop_profiler("total", args.profiler_path)
                    return

            except fluid.core.EOFException:
                input_field.loader.reset()
                break

        time_consumed = time.time() - pass_start_time

    if args.save_checkpoint:
        save_checkpoint(args, exe, train_prog, "step_final")

    if args.save_param:
        save_param(args, exe, train_prog, "step_final")

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


if __name__ == "__main__":
    args = PDConfig(yaml_file="./transformer.yaml")
    args.build()
    args.Print()
    check_gpu(args.use_cuda)
    check_version()

    do_train(args)