train.py 9.4 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12
import os
import time
import sys

import argparse
import logging
import numpy as np
import yaml
from attrdict import AttrDict
from pprint import pprint

import paddle
13
import paddle.distributed.fleet as fleet
Z
Zeyu Chen 已提交
14 15
import paddle.distributed as dist

L
liu zhengxi 已提交
16
from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
Z
Zeyu Chen 已提交
17 18 19

sys.path.append("../")
import reader
L
liu zhengxi 已提交
20
from util.record import AverageStatistical
Z
Zeyu Chen 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

# Layout of every log line emitted by this script:
# "<timestamp>-<LEVEL>: <message>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Configure the root logger once at import time; INFO and above are shown.
logging.basicConfig(format=FORMAT, level=logging.INFO)
# Module logger; it propagates to the root handler configured above.
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command line of the training script.

    Returns:
        argparse.Namespace: carries a single ``config`` attribute, the path
        of the YAML configuration file (defaults to
        ``../configs/transformer.big.yaml``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default="../configs/transformer.big.yaml",
        help="Path of the config file. ")
    return cli.parse_args()


def _parse_seed(value):
    """Convert the ``random_seed`` config entry into a Python literal.

    YAML configs may spell the seed as an integer or as the string ``"None"``
    (``yaml.safe_load`` keeps that as a string), so the value is round-tripped
    through ``str`` and parsed as a literal. ``ast.literal_eval`` accepts the
    same literal values the previous ``eval`` call did, but cannot execute
    arbitrary code taken from a config file.

    Raises:
        ValueError: if the value is not a Python literal.
    """
    import ast

    return ast.literal_eval(str(value))


def _loss_normalizer(label_smooth_eps, trg_vocab_size):
    """Return the best (lowest) cross-entropy reachable with label smoothing.

    This is the entropy of a target distribution with ``1 - eps`` on the gold
    token and ``eps`` split evenly over the remaining ``trg_vocab_size - 1``
    tokens. The ``1e-20`` keeps the second ``log`` finite when ``eps`` is 0.
    """
    eps = label_smooth_eps
    return -((1. - eps) * np.log(1. - eps) +
             eps * np.log(eps / (trg_vocab_size - 1) + 1e-20))


def _build_feed(data, trainer_count):
    """Build the list-of-dicts feed passed to ``Executor.run``.

    ``data[i]`` is indexed as ``[src_word, trg_word, lbl_word]`` for each of
    the first ``trainer_count`` entries.
    """
    return [{
        'src_word': data[i][0],
        'trg_word': data[i][1],
        'lbl_word': data[i][2],
    } for i in range(trainer_count)]


def _setup_places(args):
    """Initialise fleet if needed and pick the execution places.

    Returns:
        tuple: ``(places, trainer_count)``. In distributed mode fleet is
        initialised in collective mode. On GPU, the process uses the single
        device named by ``FLAGS_selected_gpus`` (default ``"0"``) and
        ``trainer_count`` is 1. Otherwise ``trainer_count`` is the number of
        places returned.
    """
    if args.is_distributed:
        fleet.init(is_collective=True)
        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        if args.use_gpu:
            return paddle.CUDAPlace(gpu_id), 1
        places = paddle.static.cpu_places()
        return places, len(places)

    places = (paddle.static.cuda_places()
              if args.use_gpu else paddle.static.cpu_places())
    return places, len(places)


def _save_checkpoint(args, program, tag):
    """Save ``program`` to ``<save_model>/<tag>/transformer`` on rank 0.

    Does nothing when ``args.save_model`` is falsy or on non-zero ranks.
    """
    if args.save_model and dist.get_rank() == 0:
        model_path = os.path.join(args.save_model, tag, "transformer")
        paddle.static.save(program, model_path)


def _train_static(args):
    """Build the static program and run the training loop.

    Assumes static mode has already been enabled by the caller. Returns early
    (without the final save) once ``args.max_iter`` steps have run.
    """
    places, trainer_count = _setup_places(args)

    # Set seed for CE
    random_seed = _parse_seed(args.random_seed)
    if random_seed is not None:
        paddle.seed(random_seed)

    # Define data loader (the eval loader is not used by this script).
    train_loader, _ = reader.create_data_loader(args, places)

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(train_program, startup_program):
        src_word = paddle.static.data(
            name="src_word", shape=[None, None], dtype="int64")
        trg_word = paddle.static.data(
            name="trg_word", shape=[None, None], dtype="int64")
        lbl_word = paddle.static.data(
            name="lbl_word", shape=[None, None, 1], dtype="int64")

        # Define model
        transformer = TransformerModel(
            src_vocab_size=args.src_vocab_size,
            trg_vocab_size=args.trg_vocab_size,
            max_length=args.max_length + 1,
            n_layer=args.n_layer,
            n_head=args.n_head,
            d_model=args.d_model,
            d_inner_hid=args.d_inner_hid,
            dropout=args.dropout,
            weight_sharing=args.weight_sharing,
            bos_id=args.bos_idx,
            eos_id=args.eos_idx)
        # Define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps, args.bos_idx)

        logits = transformer(src_word=src_word, trg_word=trg_word)
        sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

        scheduler = paddle.optimizer.lr.NoamDecay(
            args.d_model, args.warmup_steps, args.learning_rate, last_epoch=0)

        # Define optimizer
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameters=transformer.parameters())

        if args.is_distributed:
            build_strategy = paddle.static.BuildStrategy()
            exec_strategy = paddle.static.ExecutionStrategy()
            dist_strategy = fleet.DistributedStrategy()
            dist_strategy.build_strategy = build_strategy
            dist_strategy.execution_strategy = exec_strategy
            dist_strategy.fuse_grad_size_in_MB = 16

            if args.use_amp:
                dist_strategy.amp = True
                dist_strategy.amp_configs = {
                    'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
                    'init_loss_scaling': args.scale_loss,
                }

            optimizer = fleet.distributed_optimizer(
                optimizer, strategy=dist_strategy)
        elif args.use_amp:
            amp_list = paddle.static.amp.AutoMixedPrecisionLists(
                custom_white_list=['softmax', 'layer_norm'],
                custom_black_list=['lookup_table_v2'])
            optimizer = paddle.static.amp.decorate(
                optimizer,
                amp_list,
                init_loss_scaling=args.scale_loss,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=args.use_pure_fp16)

        optimizer.minimize(avg_cost)

    if args.is_distributed:
        exe = paddle.static.Executor(places)
        run_program = train_program
    else:
        exe = paddle.static.Executor()
        build_strategy = paddle.static.BuildStrategy()
        exec_strategy = paddle.static.ExecutionStrategy()
        run_program = paddle.static.CompiledProgram(
            train_program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

    exe.run(startup_program)

    if not args.is_distributed and args.use_amp:
        optimizer.amp_init(places[0])

    # the best cross-entropy value with label smoothing
    loss_normalizer = _loss_normalizer(args.label_smooth_eps,
                                       args.trg_vocab_size)

    step_idx = 0

    # For benchmark
    reader_cost_avg = AverageStatistical()
    batch_cost_avg = AverageStatistical()
    batch_ips_avg = AverageStatistical()

    for pass_id in range(args.epoch):
        batch_id = 0
        batch_start = time.time()
        for data in train_loader:
            # NOTE: used for benchmark and use None as default.
            if args.max_iter and step_idx == args.max_iter:
                return
            if trainer_count == 1:
                data = [data]
            train_reader_cost = time.time() - batch_start

            outs = exe.run(run_program,
                           feed=_build_feed(data, trainer_count),
                           fetch_list=[sum_cost.name, token_num.name])
            scheduler.step()

            train_batch_cost = time.time() - batch_start
            reader_cost_avg.record(train_reader_cost)
            batch_cost_avg.record(train_batch_cost)
            batch_ips_avg.record(train_batch_cost, np.asarray(outs[1]).sum())

            if step_idx % args.print_step == 0:
                # Sum the cost from multi-devices
                total_sum_cost = np.array(outs[0]).sum()
                total_token_num = np.array(outs[1]).sum()
                total_avg_cost = total_sum_cost / total_token_num
                # Cap the exponent so the perplexity cannot overflow; keep it
                # a scalar so "%f" never has to convert an array.
                ppl = np.exp(min(total_avg_cost, 100))

                if step_idx == 0:
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f", step_idx, pass_id,
                        batch_id, total_avg_cost,
                        total_avg_cost - loss_normalizer, ppl)
                else:
                    # Steps per second over the last print_step batches.
                    avg_speed = (args.print_step /
                                 batch_cost_avg.get_total_time())
                    logger.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
                        "batch_cost: %.5f sec, reader_cost: %.5f sec, tokens: %d, "
                        "ips: %.5f words/sec", step_idx, pass_id, batch_id,
                        total_avg_cost, total_avg_cost - loss_normalizer, ppl,
                        avg_speed, batch_cost_avg.get_average(),
                        reader_cost_avg.get_average(),
                        batch_ips_avg.get_total_cnt(),
                        batch_ips_avg.get_average_per_sec())
                reader_cost_avg.reset()
                batch_cost_avg.reset()
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                _save_checkpoint(args, train_program, "step_" + str(step_idx))

            batch_id += 1
            step_idx += 1
            batch_start = time.time()

    _save_checkpoint(args, train_program, "step_final")


def do_train(args):
    """Train a Transformer with Paddle's static-graph executor.

    Args:
        args: attribute-style config (an ``AttrDict`` from the YAML file in
            ``__main__``). Reads device/distribution flags (``is_distributed``,
            ``use_gpu``), ``random_seed``, model and optimizer hyper-parameters,
            AMP settings (``use_amp``, ``use_pure_fp16``, ``scale_loss``) and
            loop controls (``epoch``, ``max_iter``, ``print_step``,
            ``save_step``, ``save_model``).

    Static mode is enabled on entry and always disabled again on exit,
    including the early exit taken when ``max_iter`` is reached.
    """
    paddle.enable_static()
    try:
        _train_static(args)
    finally:
        paddle.disable_static()


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    # Decode the config as UTF-8 explicitly so reading it does not depend on
    # the platform locale. safe_load (not load) keeps the config from
    # constructing arbitrary Python objects.
    with open(yaml_file, 'rt', encoding='utf-8') as f:
        args = AttrDict(yaml.safe_load(f))
    pprint(args)

    do_train(args)