# train.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT pretraining."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
import numpy as np
import multiprocessing

import paddle
import paddle.fluid as fluid

from reader.pretraining import DataReader
from model.bert import BertModel, BertConfig
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments
from utils.init import init_checkpoint, init_pretraining_params

# yapf: disable
# Command-line interface. Options are registered through the project's
# ArgumentGroup helper, one group per concern (model, training, logging,
# data, run type). The module docstring is passed as the parser description;
# passing it positionally would set `prog` (the program name shown in the
# usage line) instead.
parser = argparse.ArgumentParser(description=__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("bert_config_path",      str,  "./config/bert_config.json",  "Path to the json file for bert model config.")
model_g.add_arg("init_checkpoint",       str,  None,                         "Init checkpoint to resume training from.")
model_g.add_arg("checkpoints",           str,  "checkpoints",                "Path to save checkpoints.")
model_g.add_arg("weight_sharing",        bool, True,                         "If set, share weights between word embedding and masked lm.")
model_g.add_arg("generate_neg_sample",   bool, True,                         "If set, randomly generate negative samples by positive samples.")

train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch",             int,    100,     "Number of epochs for training.")
train_g.add_arg("learning_rate",     float,  0.0001,  "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler",      str,    "linear_warmup_decay",
                "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay",      float,  0.01,    "Weight decay rate for L2 regularizer.")
train_g.add_arg("num_train_steps",   int,    1000000, "Total steps to perform pretraining.")
train_g.add_arg("warmup_steps",      int,    4000,    "Total steps to perform warmup when pretraining.")
train_g.add_arg("save_steps",        int,    10000,   "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps",  int,    1000,    "The steps interval to evaluate model performance.")
train_g.add_arg("use_fp16",          bool,   False,   "Whether to use fp16 mixed precision training.")
train_g.add_arg("loss_scaling",      float,  1.0,
                "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")

log_g = ArgumentGroup(parser,     "logging", "logging related.")
log_g.add_arg("skip_steps",          int,    10,    "The steps interval to print loss.")
log_g.add_arg("verbose",             bool,   False, "Whether to output verbose log.")

data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir",            str,  "./data/train/",       "Path to training data.")
data_g.add_arg("validation_set_dir",  str,  "./data/validation/",  "Path to validation data.")
data_g.add_arg("test_set_dir",        str,  None,                  "Path to test data.")
data_g.add_arg("vocab_path",          str,  "./config/vocab.txt",  "Vocabulary path.")
data_g.add_arg("max_seq_len",         int,  512,                   "Tokens' number of the longest sequence allowed.")
data_g.add_arg("batch_size",          int,  8192,
               "The total number of examples in one batch for training, see also --in_tokens.")
data_g.add_arg("in_tokens",           bool, True,
               "If set, the batch size will be the maximum number of tokens in one batch. "
               "Otherwise, it will be the maximum number of examples in one batch.")

run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("is_distributed",               bool,   False,  "If set, then start distributed training.")
run_type_g.add_arg("use_cuda",                     bool,   True,   "If set, use GPU for training.")
run_type_g.add_arg("use_fast_executor",            bool,   False,  "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int,    1,      "The iteration intervals to clean up temporary variables.")
run_type_g.add_arg("do_test",                      bool,   False,  "Whether to perform evaluation on test data set.")

# Parsed once at import time; create_model() below reads this module-level
# namespace directly.
args = parser.parse_args()
# yapf: enable.


def create_model(pyreader_name, bert_config):
    """Build the BERT pretraining graph in the current default program.

    Reads ``max_seq_len``, ``weight_sharing``, ``use_fp16`` and
    ``loss_scaling`` from the module-level ``args``.

    Args:
        pyreader_name: Name given to the ``py_reader`` feeding the graph.
        bert_config: Configuration object forwarded to ``BertModel``.

    Returns:
        A tuple ``(pyreader, next_sent_acc, mask_lm_loss, total_loss)``.
        ``total_loss`` is multiplied by ``args.loss_scaling`` when fp16 is
        enabled and the scaling factor is greater than 1.
    """
    # Seven input slots, in the order they are unpacked below: four
    # per-token slots of shape [-1, max_seq_len, 1] followed by three
    # [-1, 1] slots. Only input_mask is float32; the rest are int64.
    pyreader = fluid.layers.py_reader(
        capacity=70,
        shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                [-1, args.max_seq_len, 1],
                [-1, args.max_seq_len, 1], [-1, 1], [-1, 1],
                [-1, 1]],
        dtypes=[
            'int64', 'int64', 'int64', 'float32', 'int64', 'int64', 'int64'
        ],
        lod_levels=[0, 0, 0, 0, 0, 0, 0],
        name=pyreader_name,
        use_double_buffer=True)

    (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
     labels) = fluid.layers.read_file(pyreader)

    bert = BertModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        input_mask=input_mask,
        config=bert_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_fp16)

    next_sent_acc, mask_lm_loss, total_loss = bert.get_pretraining_output(
        mask_label, mask_pos, labels)

    # Loss scaling only applies to mixed-precision training.
    if args.use_fp16 and args.loss_scaling > 1.0:
        total_loss *= args.loss_scaling

    return pyreader, next_sent_acc, mask_lm_loss, total_loss


def predict_wrapper(args,
                    exe,
                    bert_config,
                    test_prog=None,
                    pyreader=None,
                    fetch_list=None):
    """Prepare an evaluation closure over the test or validation set.

    Args:
        args: Parsed command-line namespace. ``do_test`` selects
            ``test_set_dir`` over ``validation_set_dir``.
        exe: Executor used to run ``test_prog``.
        bert_config: Model configuration; ``bert_config['vocab_size']`` is
            passed to the data reader.
        test_prog: Program to evaluate.
        pyreader: Reader feeding ``test_prog``.
        fetch_list: Names to fetch per step; three values are expected
            (next-sentence accuracy, masked-LM loss, total loss).

    Returns:
        A zero-argument callable ``predict`` that runs one pass over the
        data and returns ``(cost, lm_cost, acc, steps, speed)``. The first
        three are sums of the fetched values over all steps (callers divide
        by ``steps``); ``speed`` is evaluation steps per second.

    Raises:
        ValueError: If ``args.do_test`` is set without
            ``args.init_checkpoint``.
    """
    # Context to do validation.
    data_path = args.test_set_dir if args.do_test else args.validation_set_dir
    data_reader = DataReader(
        data_path,
        vocab_path=args.vocab_path,
        batch_size=args.batch_size,
        in_tokens=args.in_tokens,
        voc_size=bert_config['vocab_size'],
        shuffle_files=False,
        epoch=1,
        max_seq_len=args.max_seq_len,
        is_test=True)

    if args.do_test:
        # Explicit raise rather than assert: asserts are stripped under -O,
        # which would let a None checkpoint path through.
        if args.init_checkpoint is None:
            raise ValueError(
                "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' "
                "to specify your pretrained model checkpoints")

        init_pretraining_params(exe, args.init_checkpoint, test_prog)

    def predict(exe=exe, pyreader=pyreader):
        """Run one full pass over the data and return accumulated metrics."""
        pyreader.decorate_tensor_provider(data_reader.data_generator())
        pyreader.start()

        cost = 0
        lm_cost = 0
        acc = 0
        steps = 0
        time_begin = time.time()
        while True:
            try:
                each_next_acc, each_mask_lm_cost, each_total_cost = exe.run(
                    fetch_list=fetch_list, program=test_prog)
                acc += each_next_acc
                lm_cost += each_mask_lm_cost
                cost += each_total_cost
                steps += 1
                if args.do_test and steps % args.skip_steps == 0:
                    print("[test_set] steps: %d" % steps)

            except fluid.core.EOFException:
                # Reader exhausted: reset it so predict() can be called again.
                pyreader.reset()
                break

        used_time = time.time() - time_begin
        # Throughput of this pass: steps actually run over elapsed time.
        speed = steps / used_time if used_time > 0 else 0.0
        return cost, lm_cost, acc, steps, speed

    return predict


def test(args):
    """Evaluate a pretrained checkpoint and print aggregate metrics.

    Builds a test-only program, runs its startup program on the first GPU
    (or on CPU when ``args.use_cuda`` is false), then performs one pass via
    ``predict_wrapper`` and prints mean loss, perplexity, next-sentence
    accuracy and speed.

    Args:
        args: Parsed command-line namespace.
    """
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    test_prog = fluid.Program()
    test_startup = fluid.Program()
    with fluid.program_guard(test_prog, test_startup):
        with fluid.unique_name.guard():
            test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                pyreader_name='test_reader', bert_config=bert_config)

    test_prog = test_prog.clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(test_startup)

    predict = predict_wrapper(
        args,
        exe,
        bert_config,
        test_prog=test_prog,
        pyreader=test_pyreader,
        fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name])

    print("test begin")
    loss, lm_loss, acc, steps, speed = predict()
    # predict() returns per-step sums; divide by the step count for means.
    # "global ppl" is the exponential of the mean masked-LM loss.
    print(
        "[test_set] loss: %f, global ppl: %f, next_sent_acc: %f, speed: %f steps/s"
        % (np.mean(np.array(loss) / steps),
           np.exp(np.mean(np.array(lm_loss) / steps)),
           np.mean(np.array(acc) / steps), speed))


def _run_train_step(train_exe, train_program, use_ngraph, fetch_list):
    """Run one training iteration and return the fetched values.

    On the nGraph path ``train_exe`` is the plain Executor and the program
    must be passed on every call; otherwise it is the ParallelExecutor that
    was constructed with ``main_program=train_program``.
    """
    if use_ngraph:
        return train_exe.run(fetch_list=fetch_list, program=train_program)
    return train_exe.run(fetch_list=fetch_list)


def train(args):
    """Run BERT pretraining, with periodic logging, saving and validation.

    Builds the training and validation programs, optionally transpiles the
    training program for nccl2 distributed mode, then loops until
    ``args.num_train_steps`` is reached or the training reader is exhausted.

    Args:
        args: Parsed command-line namespace.

    Raises:
        ValueError: In distributed mode, if the ``worker_endpoints`` or
            ``current_endpoint`` environment variable is missing, or the
            current endpoint is not listed among the worker endpoints.
    """
    print("pretraining start")
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    train_program = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(train_program, startup_prog):
        with fluid.unique_name.guard():
            train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                pyreader_name='train_reader', bert_config=bert_config)
            scheduled_lr = optimization(
                loss=total_loss,
                warmup_steps=args.warmup_steps,
                num_train_steps=args.num_train_steps,
                learning_rate=args.learning_rate,
                train_program=train_program,
                startup_prog=startup_prog,
                weight_decay=args.weight_decay,
                scheduler=args.lr_scheduler,
                use_fp16=args.use_fp16,
                loss_scaling=args.loss_scaling)

            # Keep the fetched metrics out of memory reuse.
            fluid.memory_optimize(
                input_program=train_program,
                skip_opt_set=[
                    next_sent_acc.name, mask_lm_loss.name, total_loss.name
                ])

    # NOTE(review): the metric variables are deliberately rebound to the
    # validation program's outputs here and their *names* are then used for
    # both programs. This presumably relies on each program being built
    # under a fresh unique_name guard so the names coincide -- confirm
    # before reordering.
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                pyreader_name='test_reader', bert_config=bert_config)

    test_prog = test_prog.clone(for_test=True)
    metric_names = [next_sent_acc.name, mask_lm_loss.name, total_loss.name]

    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

    print("Device count %d" % dev_count)
    if args.verbose:
        # With token-based batching, batch_size counts tokens, so convert
        # it to a number of examples for the estimate.
        if args.in_tokens:
            usage_batch_size = args.batch_size // args.max_seq_len
        else:
            usage_batch_size = args.batch_size
        lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
            program=train_program, batch_size=usage_batch_size)
        print("Theoretical memory usage in training: %.3f - %.3f %s" %
              (lower_mem, upper_mem, unit))

    nccl2_num_trainers = 1
    nccl2_trainer_id = 0
    print("args.is_distributed:", args.is_distributed)
    if args.is_distributed:
        worker_endpoints_env = os.getenv("worker_endpoints")
        current_endpoint = os.getenv("current_endpoint")
        # Fail with a clear message instead of AttributeError on None.split.
        if not worker_endpoints_env or not current_endpoint:
            raise ValueError(
                "Distributed training requires the 'worker_endpoints' and "
                "'current_endpoint' environment variables to be set.")
        worker_endpoints = worker_endpoints_env.split(",")
        trainers_num = len(worker_endpoints)
        trainer_id = worker_endpoints.index(current_endpoint)
        if trainer_id == 0:
            print("train_id == 0, sleep 60s")
            time.sleep(60)
        print("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
              trainer_id:{}"
                            .format(worker_endpoints, trainers_num,
                                    current_endpoint, trainer_id))

        # prepare nccl2 env.
        config = fluid.DistributeTranspilerConfig()
        config.mode = "nccl2"
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id,
            trainers=worker_endpoints_env,
            current_endpoint=current_endpoint,
            program=train_program,
            startup_program=startup_prog)
        nccl2_num_trainers = trainers_num
        nccl2_trainer_id = trainer_id

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if args.init_checkpoint:
        init_checkpoint(exe, args.init_checkpoint, train_program, args.use_fp16)

    data_reader = DataReader(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        in_tokens=args.in_tokens,
        vocab_path=args.vocab_path,
        voc_size=bert_config['vocab_size'],
        epoch=args.epoch,
        max_seq_len=args.max_seq_len,
        generate_neg_sample=args.generate_neg_sample)

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.use_experimental_executor = args.use_fast_executor
    exec_strategy.num_threads = dev_count
    exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

    # use_ngraph is for CPU only, please refer to README_ngraph.md for details
    # Any non-empty value of FLAGS_use_ngraph selects the plain Executor.
    use_ngraph = os.getenv('FLAGS_use_ngraph')
    if not use_ngraph:
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_cuda,
            loss_name=total_loss.name,
            exec_strategy=exec_strategy,
            main_program=train_program,
            num_trainers=nccl2_num_trainers,
            trainer_id=nccl2_trainer_id)
    else:
        train_exe = exe

    if args.validation_set_dir:
        predict = predict_wrapper(
            args,
            exe,
            bert_config,
            test_prog=test_prog,
            pyreader=test_pyreader,
            fetch_list=list(metric_names))

    train_pyreader.decorate_tensor_provider(data_reader.data_generator())
    train_pyreader.start()
    steps = 0
    cost = []
    lm_cost = []
    acc = []
    time_begin = time.time()
    while steps < args.num_train_steps:
        try:
            # The step counter advances by the number of trainers on every
            # iteration; the logging interval is scaled to match.
            steps += nccl2_num_trainers
            skip_steps = args.skip_steps * nccl2_num_trainers

            # Only trainer 0 logs, saves and validates.
            if nccl2_trainer_id != 0:
                _run_train_step(train_exe, train_program, use_ngraph, [])
                continue

            if steps % skip_steps != 0:
                _run_train_step(train_exe, train_program, use_ngraph, [])
            else:
                each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = _run_train_step(
                    train_exe, train_program, use_ngraph,
                    metric_names + [scheduled_lr.name])

                acc.extend(each_next_acc)
                lm_cost.extend(each_mask_lm_cost)
                cost.extend(each_total_cost)

                print("feed_queue size", train_pyreader.queue.size())
                time_end = time.time()
                used_time = time_end - time_begin
                epoch, current_file_index, total_file, current_file = data_reader.get_progress(
                )
                print("current learning_rate:%f" % np_lr[0])
                print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                      "ppl: %f, next_sent_acc: %f, speed: %f steps/s, file: %s"
                      % (epoch, current_file_index, total_file, steps,
                         np.mean(np.array(cost)),
                         np.mean(np.exp(np.array(lm_cost))),
                         np.mean(np.array(acc)), skip_steps / used_time,
                         current_file))
                # Start a fresh logging window.
                cost = []
                lm_cost = []
                acc = []
                time_begin = time.time()

            if steps % args.save_steps == 0:
                save_path = os.path.join(args.checkpoints, "step_" + str(steps))
                fluid.io.save_persistables(exe, save_path, train_program)

            if args.validation_set_dir and steps % args.validation_steps == 0:
                # Read the epoch here rather than reusing the value bound in
                # the logging branch, which is unset if validation fires
                # before the first log line.
                epoch = data_reader.get_progress()[0]
                vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
                )
                print("[validation_set] epoch: %d, step: %d, "
                      "loss: %f, global ppl: %f, batch-averged ppl: %f, "
                      "next_sent_acc: %f, speed: %f steps/s" %
                      (epoch, steps,
                       np.mean(np.array(vali_cost) / vali_steps),
                       np.exp(np.mean(np.array(vali_lm_cost) / vali_steps)),
                       np.mean(np.exp(np.array(vali_lm_cost) / vali_steps)),
                       np.mean(np.array(vali_acc) / vali_steps), vali_speed))

        except fluid.core.EOFException:
            # Training data exhausted before num_train_steps was reached.
            train_pyreader.reset()
            break

if __name__ == '__main__':
    # Echo the parsed arguments, then run evaluation when --do_test is set
    # and pretraining otherwise.
    print_arguments(args)
    entry_point = test if args.do_test else train
    entry_point(args)