# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import itertools
import os
import random
import time
import h5py
from functools import partial
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import distutils.util

import paddle
import paddle.distributed.fleet as fleet
from paddle.io import DataLoader, Dataset

from paddlenlp.transformers import BertForPretraining, BertModel, BertPretrainingCriterion
from paddlenlp.transformers import BertTokenizer
from data import create_data_holder, create_pretraining_dataset

# Maps each supported --model_type value to its (model class, tokenizer class) pair.
MODEL_CLASSES = dict(bert=(BertForPretraining, BertTokenizer))


def parse_args():
    """Define and parse the command-line options of the pretraining script.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # Help-text suffixes listing the registered model types and the shortcut
    # names exposed by each entry's last class (the tokenizer class).
    model_type_names = ", ".join(MODEL_CLASSES.keys())
    shortcut_names = ", ".join(
        name for classes in MODEL_CLASSES.values()
        for name in classes[-1].pretrained_init_configuration.keys())
    str2bool = distutils.util.strtobool

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Device and model selection.
    add("--select_device",
        type=str,
        default="gpu",
        help="The device that selecting for the training, must be gpu/xpu.")
    add("--model_type",
        type=str,
        default=None,
        required=True,
        help="Model type selected in the list: " + model_type_names)
    add("--model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + shortcut_names)

    # Data locations.
    add("--input_dir",
        type=str,
        default=None,
        required=True,
        help="The input directory where the data will be read from.")
    add("--output_dir",
        type=str,
        default=None,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    add("--max_predictions_per_seq",
        type=int,
        default=80,
        help="The maximum total of masked tokens in input sequence")

    # Optimization hyper-parameters.
    add("--batch_size",
        type=int,
        default=8,
        help="Batch size per GPU/CPU for training.")
    add("--learning_rate",
        type=float,
        default=5e-5,
        help="The initial learning rate for Adam.")
    add("--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay if we apply some.")
    add("--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for Adam optimizer.")
    # NOTE(review): --max_grad_norm is not referenced by do_train in this file.
    add("--max_grad_norm",
        type=float,
        default=1.0,
        help="Max gradient norm.")
    # NOTE(review): the help mentions num_train_epochs, which is not an
    # option of this parser.
    add("--max_steps",
        type=int,
        default=-1,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    add("--warmup_steps",
        type=int,
        default=0,
        help="Linear warmup over warmup_steps.")

    # Logging, checkpointing and reproducibility.
    add("--logging_steps",
        type=int,
        default=500,
        help="Log every X updates steps.")
    add("--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps.")
    add("--seed",
        type=int,
        default=42,
        help="Random seed for initialization")

    # Mixed-precision options.
    add("--use_amp",
        type=str2bool,
        default=False,
        help="Enable mixed precision training.")
    add("--enable_addto",
        type=str2bool,
        default=False,
        help="Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training.")
    add("--scale_loss",
        type=float,
        default=1.0,
        help="The value of scale_loss for fp16.")

    return parser.parse_args()


def select_dataset_file_for_each_worker(files, f_start_id, worker_num,
                                        worker_index):
    """
    Pick the data file that worker ``worker_index`` reads at file step ``f_start_id``.

    The selected index is ``f_start_id * worker_num + worker_index`` modulo
    ``len(files)``. When there are more workers than files, an extra
    ``(worker_num % len(files)) * f_start_id`` offset is added before the
    modulo.

    Args:
        files (list): candidate data file paths.
        f_start_id (int): index of the current file step.
        worker_num (int): total number of workers.
        worker_index (int): index of the calling worker.

    Returns:
        The selected element of ``files``.

    Raises:
        ValueError: if ``files`` is empty.
    """
    num_files = len(files)
    if num_files == 0:
        # Without this check the modulo below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("No dataset files to select from.")
    index = f_start_id * worker_num + worker_index
    if worker_num > num_files:
        index += (worker_num % num_files) * f_start_id
    return files[index % num_files]


def reset_program_state_dict(model, state_dict):
    """
    Build a freshly initialized state dict for every non-LayerNorm parameter.

    Each parameter whose ``name`` does not contain "layer_norm" gets new
    values drawn from a normal distribution with mean 0 and standard
    deviation ``initializer_range``. LayerNorm parameters are left out of the
    result, so the caller does not overwrite them (presumably they keep their
    startup-program initialization).

    Args:
        model: exposes ``initializer_range`` directly or via
            ``model.bert.config["initializer_range"]``.
        state_dict (dict): values must provide ``name``, ``dtype`` and
            ``shape``; the keys are ignored.

    Returns:
        dict: maps each reset parameter's ``name`` to a NumPy array of the
        parameter's shape (float64 for FP64 parameters, float32 otherwise).
    """
    # String forms an FP64 dtype may print as: "VarType.FP64" is what this
    # script originally matched; "paddle.float64" covers Paddle releases that
    # print dtypes with the ``paddle.`` prefix (NOTE(review): confirm for the
    # Paddle version in use).
    fp64_dtype_names = ("VarType.FP64", "paddle.float64")

    scale = model.initializer_range if hasattr(model, "initializer_range") \
        else model.bert.config["initializer_range"]

    new_state_dict = dict()
    for p in state_dict.values():
        if "layer_norm" in p.name:
            continue
        dtype_str = "float64" if str(p.dtype) in fp64_dtype_names \
            else "float32"
        new_state_dict[p.name] = np.random.normal(
            loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
    return new_state_dict

def create_strategy(args=None):
    """
    Create the build strategy and the execution strategy for the static graph.

    Args:
        args (argparse.Namespace, optional): parsed command-line arguments;
            only ``args.enable_addto`` is read. When omitted, the module-level
            ``args`` (assigned in the ``__main__`` guard) is used, which keeps
            the original zero-argument call working.

    Returns:
        build_strategy (paddle.static.BuildStrategy): build strategy with
            ``enable_addto`` taken from ``args``.
        exec_strategy (paddle.static.ExecutionStrategy): exec strategy with a
            single thread and ``num_iteration_per_drop_scope = 10000``.

    Raises:
        NameError: if ``args`` is omitted and no module-level ``args`` exists
            (e.g. when this file is imported rather than run as a script).
    """
    if args is None:
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "create_strategy() needs `args`: pass it explicitly or run "
                "this file as a script so the module-level `args` is set."
            ) from None

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.enable_addto = args.enable_addto

    exec_strategy = paddle.static.ExecutionStrategy()
    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10000
    return build_strategy, exec_strategy


def build_compiled_program(main_program, loss):
    """
    Wrap ``main_program`` in a data-parallel ``CompiledProgram``.

    Args:
        main_program (paddle.static.Program): program to compile.
        loss (Variable): loss variable; its ``name`` is passed as
            ``loss_name`` to ``with_data_parallel``.

    Returns:
        The value returned by ``CompiledProgram.with_data_parallel``.
    """
    strategies = create_strategy()
    compiled = paddle.static.CompiledProgram(main_program)
    return compiled.with_data_parallel(
        loss_name=loss.name,
        exec_strategy=strategies[1],
        build_strategy=strategies[0])


def dist_optimizer(args, optimizer):
    """
    Wrap a regular optimizer with ``fleet.distributed_optimizer``.

    Args:
        args (argparse.Namespace): reads ``use_amp`` and ``scale_loss``.
        optimizer: the single-process optimizer to wrap.

    Returns:
        The distributed optimizer returned by ``fleet.distributed_optimizer``.
    """
    build_strategy, exec_strategy = create_strategy()

    strategy = fleet.DistributedStrategy()
    strategy.execution_strategy = exec_strategy
    strategy.build_strategy = build_strategy
    strategy.fuse_grad_size_in_MB = 16

    # Mixed precision is configured through the fleet strategy here; the
    # single-worker path in do_train uses the fluid AMP decorator instead.
    if args.use_amp:
        strategy.amp = True
        strategy.amp_configs = dict(
            custom_white_list=['softmax', 'layer_norm', 'gelu'],
            init_loss_scaling=args.scale_loss)

    return fleet.distributed_optimizer(optimizer, strategy=strategy)


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and Paddle with ``seed``."""
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)


class WorkerInitObj:
    """Callable that reseeds NumPy and ``random`` with ``seed + id``."""

    def __init__(self, seed):
        # Base seed; each call offsets it by the id it receives.
        self.seed = seed

    def __call__(self, id):
        worker_seed = self.seed + id
        np.random.seed(seed=worker_seed)
        random.seed(worker_seed)


def do_train(args):
    # Initialize the paddle and paddle fleet execute enviroment
    paddle.enable_static()
249
    place = paddle.set_device(args.select_device)
Z
Zeyu Chen 已提交
250 251
    fleet.init(is_collective=True)

W
WangXi 已提交
252 253 254
    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

Z
Zeyu Chen 已提交
255 256
    # Create the random seed for the worker
    set_seed(args.seed)
W
WangXi 已提交
257
    worker_init = WorkerInitObj(args.seed + worker_index)
Z
Zeyu Chen 已提交
258 259

    # Define the input data in the static mode
260 261
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
Z
Zeyu Chen 已提交
262 263 264 265 266 267 268 269 270 271 272
    data_holders = create_data_holder(args)

    [
        input_ids, segment_ids, input_mask, masked_lm_positions,
        masked_lm_labels, next_sentence_labels, masked_lm_scale
    ] = data_holders

    # Define the model structure in static mode
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
273 274 275 276
    config = model_class.pretrained_init_configuration[args.model_name_or_path]
    if config["vocab_size"] % 8 != 0:
        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
    model = BertForPretraining(BertModel(**config))
Z
Zeyu Chen 已提交
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
    criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
    prediction_scores, seq_relationship_score = model(
        input_ids=input_ids,
        token_type_ids=segment_ids,
        attention_mask=input_mask,
        masked_positions=masked_lm_positions)
    loss = criterion(prediction_scores, seq_relationship_score,
                     masked_lm_labels, next_sentence_labels, masked_lm_scale)

    # Define the dynamic learing_reate scheduler and optimizer
    lr_scheduler = paddle.optimizer.lr.LambdaDecay(
        args.learning_rate,
        lambda current_step, num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps if args.max_steps > 0 else
        (len(train_data_loader) * args.num_train_epochs): float(
            current_step) / float(max(1, num_warmup_steps))
        if current_step < num_warmup_steps else max(
            0.0,
            float(num_training_steps - current_step) / float(
                max(1, num_training_steps - num_warmup_steps))))

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])
W
WangXi 已提交
307
    if worker_num == 1 and args.use_amp:
308
        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
309
            custom_white_list=['softmax', 'layer_norm', 'gelu'])
310 311 312 313 314
        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True)
W
WangXi 已提交
315 316 317 318

    if worker_num > 1:
        # Use the fleet api to compile the distributed optimizer
        optimizer = dist_optimizer(args, optimizer)
Z
Zeyu Chen 已提交
319 320 321 322
    optimizer.minimize(loss)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
323
    exe.run(startup_program)
Z
Zeyu Chen 已提交
324 325 326 327
    state_dict = model.state_dict()

    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
328
    paddle.static.set_program_state(main_program, reset_state_dict)
W
WangXi 已提交
329 330 331 332

    if worker_num == 1:
        # Construct the compiled program
        main_program = build_compiled_program(main_program, loss)
Z
Zeyu Chen 已提交
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    epoch = 0
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in
            f
        ]
        files.sort()
        num_files = len(files)
        random.Random(args.seed + epoch).shuffle(files)
        f_start_id = 0

        # Select one file for each worker and create the DataLoader for the file
        data_file = select_dataset_file_for_each_worker(
            files, f_start_id, worker_num, worker_index)
        train_data_loader, _ = create_pretraining_dataset(
            data_file, args.max_predictions_per_seq, args, data_holders,
            worker_init, paddle.static.cuda_places())

        for f_id in range(f_start_id + 1, len(files)):
            data_file = select_dataset_file_for_each_worker(
                files, f_id, worker_num, worker_index)
            dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                         args.max_predictions_per_seq, args,
                                         data_holders, worker_init,
                                         paddle.static.cuda_places())

            for step, batch in enumerate(train_data_loader):
                global_step += 1
366 367 368
                loss_return = exe.run(main_program,
                                      feed=batch,
                                      fetch_list=[loss])
Z
Zeyu Chen 已提交
369 370 371 372 373
                # In the new 2.0 api, must call this function to change the learning_rate
                lr_scheduler.step()
                if global_step % args.logging_steps == 0:
                    time_cost = time.time() - tic_train
                    print(
374
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, ips: %.2f sequences/s"
Z
Zeyu Chen 已提交
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398
                        % (global_step, epoch, step, loss_return[0],
                           args.logging_steps / time_cost,
                           args.logging_steps * args.batch_size / time_cost))
                    tic_train = time.time()
                if global_step % args.save_steps == 0:
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # TODO(fangzeyang): Udpate the save_params to paddle.static
                        paddle.fluid.io.save_params(exe, output_dir)
                        tokenizer.save_pretrained(output_dir)
                if global_step >= args.max_steps:
                    del train_data_loader
                    return
            del train_data_loader
            train_data_loader, data_file = dataset_future.result(timeout=None)
        epoch += 1


if __name__ == "__main__":
    # Keep `args` as a module-level name: create_strategy() is called without
    # an `args` argument and reads `args.enable_addto` from this global.
    args = parse_args()
    do_train(args)