#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import time
import os
import random
import math
import contextlib
from distutils.dir_util import mkpath

import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor

import reader

import sys
if sys.version[0] == '2':
    # Python 2 only: make implicit str/unicode conversions use UTF-8.
    reload(sys)
    sys.setdefaultencoding("utf-8")
# Must run before the `models.*` imports below, which live in shared_modules.
sys.path.append('../shared_modules/')
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from args import *
from models.model_check import check_cuda, check_version
from models.language_model import lm_model
from config import RNNConfig
import logging
import pickle

# Fixed random seed used for the programs built in main().
SEED = 123

class TimeCostAverage(object):
    """Running average of per-batch wall-clock durations.

    ``record`` accumulates samples, ``get_average`` reports their mean, and
    ``reset`` starts a new averaging window (the training loops reset it
    after every log line).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded samples."""
        self.cnt = 0
        self.total_time = 0

    def record(self, usetime):
        """Add one duration sample ``usetime``."""
        self.cnt += 1
        self.total_time += usetime

    def get_average(self):
        """Return the mean of the recorded samples, or 0 if there are none."""
        if self.cnt == 0:
            return 0
        return self.total_time / self.cnt


@contextlib.contextmanager
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
    """Optionally run the enclosed code under the Paddle profiler.

    Args:
        profile: when True, the body runs inside
            ``profiler.profiler('All', 'total', profiler_path)``; when False
            this context manager does nothing.
        profiler_path: output path handed to the profiler.
    """
    if profile:
        with profiler.profiler('All', 'total', profiler_path):
            yield
    else:
        yield


def get_current_model_para(train_prog, train_exe):
    """Snapshot every parameter of ``train_prog`` as a numpy array.

    Values are read from ``fluid.global_scope()``. Each parameter must already
    exist there, because the ``find_var`` result is used without a None check.

    Args:
        train_prog: program whose ``all_parameters()`` are collected.
        train_exe: unused; kept for call compatibility.

    Returns:
        dict mapping parameter name to a numpy array of its tensor.
    """
    # The scope is the same for every parameter; look it up once.
    scope = fluid.global_scope()
    return {
        p.name: np.array(scope.find_var(p.name).get_tensor())
        for p in train_prog.all_parameters()
    }


def save_para_npz(train_prog, train_exe, path="mode_base"):
    """Dump all parameters of ``train_prog`` into a single ``.npz`` archive.

    Args:
        train_prog: program whose parameters are saved.
        train_exe: unused; kept for call compatibility.
        path: target passed to ``np.savez``, which appends ``.npz`` when the
            name lacks it. Defaults to the historical ``"mode_base"``.

    Raises:
        KeyError: if there is no parameter named ``"embedding_para"``.
    """
    vals = get_current_model_para(train_prog, train_exe)
    # Fail early if the embedding table is missing (presumably a sanity check
    # that this is the language model's parameter set).
    if "embedding_para" not in vals:
        raise KeyError("embedding_para")
    print("begin to save model to %s" % path)
    np.savez(path, **vals)


def main():
    """Train the PTB language model, then report test perplexity.

    Parses command-line arguments and builds a training program plus a
    separate inference program with ``lm_model.lm_model``. It optionally
    restores pretrained parameters, then trains for ``config.max_epoch``
    epochs. Unless profiling, each epoch is followed by validation and a
    parameter save. Finally, the model is evaluated on the test set.

    Raises:
        Warning: if ``--init_from_pretrain_model`` is set but
            ``<path>.pdparams`` does not exist.
    """
    args = parse_args()

    # check if set use_gpu=True in paddlepaddle cpu version
    check_cuda(args.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    logger = logging.getLogger("lm")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    if args.log_path:
        file_handler = logging.FileHandler(args.log_path)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    else:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    logger.info('Running with args : {}'.format(args))

    config = RNNConfig(args)

    if not os.path.exists(args.save_model_dir):
        mkpath(args.save_model_dir)

    # define train program
    main_program = fluid.Program()
    startup_program = fluid.Program()
    if args.enable_ce:
        startup_program.random_seed, main_program.random_seed = SEED, SEED
    with fluid.program_guard(main_program, startup_program):
        with fluid.unique_name.guard():
            res_vars = lm_model.lm_model(
                config.hidden_size,
                config.vocab_size,
                num_layers=config.num_layers,
                num_steps=config.num_steps,
                init_scale=config.init_scale,
                dropout=config.dropout,
                rnn_model=config.rnn_model,
                use_dataloader=args.use_dataloader)

            # With use_dataloader, the dataloader is the last returned item.
            if args.use_dataloader:
                dataloader = res_vars[-1]
                res_vars = res_vars[:-1]
            loss, last_hidden, last_cell, feed_order = res_vars

            fluid.clip.set_gradient_clip(
                clip=fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=config.max_grad_norm))

            # The value is fed on every training step (see prepare_input and
            # the dataloader loop), so 1.0 is only a placeholder.
            learning_rate = fluid.layers.create_global_var(
                name="learning_rate",
                shape=[1],
                value=1.0,
                dtype='float32',
                persistable=True)

            optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
            optimizer.minimize(loss)

    # define inference program
    inference_program = fluid.Program()
    inference_startup_program = fluid.Program()
    inference_program.random_seed, inference_startup_program.random_seed = SEED, SEED
    with fluid.program_guard(inference_program, inference_startup_program):
        with fluid.unique_name.guard():
            lm_model.lm_model(
                config.hidden_size,
                config.vocab_size,
                num_layers=config.num_layers,
                num_steps=config.num_steps,
                init_scale=config.init_scale,
                dropout=config.dropout,
                rnn_model=config.rnn_model,
                use_dataloader=False)
    # Some op behaves differently for train and inference, we need to call
    # this clone function to ensure every op is right for inference.
    inference_program = inference_program.clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = Executor(place)
    exe.run(startup_program)

    if args.init_from_pretrain_model:
        if not os.path.exists(args.init_from_pretrain_model + '.pdparams'):
            print(args.init_from_pretrain_model)
            raise Warning("The pretrained params do not exist.")
        fluid.load(main_program, args.init_from_pretrain_model, exe)
        print("finish initing model from pretrained params from %s" %
              (args.init_from_pretrain_model))

    device_count = len(fluid.cuda_places()) if args.use_gpu else len(
        fluid.cpu_places())
    # Batch size of the non-dataloader train iterator, of evaluation, and of
    # the initial hidden/cell states.
    global_batch_size = config.batch_size * device_count

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_threads = device_count
    exec_strategy.num_iteration_per_drop_scope = 100

    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_optimizer_ops = True
    try:
        fluid.require_version(min_version='1.7.0')
        build_strategy.enable_auto_fusion = args.enable_auto_fusion
    except Exception:
        logger.info("PaddlePaddle version 1.7.0 or higher is "
                    "required when you want to enable fusion_group.")

    if args.parallel:
        train_program = fluid.compiler.CompiledProgram(
            main_program).with_data_parallel(
                loss_name=loss.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
    else:
        train_program = fluid.compiler.CompiledProgram(main_program)

    # NOTE(review): this sets an attribute on the CompiledProgram wrapper, not
    # on a Program; confirm whether Paddle reads it at all.
    train_program.random_seed = SEED

    data_path = args.data_path
    print("begin to load data")
    ptb_data = reader.get_ptb_data(data_path)
    print("finished load data")
    train_data, valid_data, test_data = ptb_data

    def generate_init_data():
        """Return zero (hidden, cell) states shaped for one global batch."""
        init_hidden = np.zeros(
            (global_batch_size, config.num_layers, config.hidden_size),
            dtype='float32')
        init_cell = np.zeros(
            (global_batch_size, config.num_layers, config.hidden_size),
            dtype='float32')
        return init_hidden, init_cell

    def generate_new_lr(epoch_id=0, device_count=1):
        """Return the decayed learning rate for ``epoch_id``, one per device.

        rate = base_learning_rate * lr_decay ** max(epoch_id + 1 -
        epoch_start_decay, 0)
        """
        new_lr = config.base_learning_rate * (config.lr_decay**max(
            epoch_id + 1 - config.epoch_start_decay, 0.0))
        lr = np.ones((device_count), dtype='float32') * new_lr
        return lr

    def prepare_input(batch,
                      init_hidden=None,
                      init_cell=None,
                      epoch_id=0,
                      with_lr=True,
                      device_count=1):
        """Build the feed dict for one ``(x, y)`` batch.

        ``x`` is reshaped to ``(-1, num_steps, 1)`` and ``y`` to ``(-1, 1)``.
        Initial states are added when given; the learning rate is added when
        ``with_lr`` is true.
        """
        x, y = batch
        x = x.reshape((-1, config.num_steps, 1))
        y = y.reshape((-1, 1))

        res = {}
        res['x'] = x
        res['y'] = y
        if init_hidden is not None:
            res['init_hidden'] = init_hidden
        if init_cell is not None:
            res['init_cell'] = init_cell
        if with_lr:
            res['learning_rate'] = generate_new_lr(epoch_id, device_count)

        return res

    def has_enough_data(data, batch_size, num_steps):
        """Return True if ``data`` yields at least one batch of this shape.

        Uses the same batch/epoch arithmetic as get_log_interval. Large batch
        sizes (e.g. 2100) can leave the smaller splits with zero batches.
        """
        batch_len = len(data) // batch_size
        epoch_size = (batch_len - 1) // num_steps
        return epoch_size >= 1

    def evaluate(data):
        """Return the perplexity of ``data`` under ``inference_program``.

        Batches use ``global_batch_size`` (config.batch_size * device_count).
        The final hidden/cell state of each batch seeds the next one. Call
        ``has_enough_data(data, global_batch_size, num_steps)`` first: with no
        batches, ``iters`` stays 0 and the division below raises.
        """
        eval_data_iter = reader.get_data_iter(data, global_batch_size,
                                              config.num_steps)
        total_loss = 0.0
        iters = 0
        init_hidden, init_cell = generate_init_data()
        for batch in eval_data_iter:
            input_data_feed = prepare_input(
                batch, init_hidden, init_cell, epoch_id=0, with_lr=False)
            # Fetches by the training graph's variable names. Both graphs are
            # built under a fresh unique_name.guard(), so the names coincide.
            fetch_outs = exe.run(
                program=inference_program,
                feed=input_data_feed,
                fetch_list=[loss.name, last_hidden.name, last_cell.name],
                use_program_cache=False)

            cost_eval = np.array(fetch_outs[0])
            init_hidden = np.array(fetch_outs[1])
            init_cell = np.array(fetch_outs[2])

            total_loss += cost_eval
            iters += config.num_steps

        ppl = np.exp(total_loss / iters)
        return ppl

    def get_log_interval(data_len):
        """Return how many batches to run between log lines (at least 1)."""
        num_batchs = data_len // config.batch_size
        epoch_size = (num_batchs - 1) // config.num_steps
        log_interval = max(1, epoch_size // 10)
        return log_interval

    def train_an_epoch(epoch_id, batch_times):
        """Run one epoch by feeding numpy batches; return its perplexity.

        Appends each step's wall-clock duration to ``batch_times``.
        """
        # get train epoch size
        log_interval = get_log_interval(len(train_data))
        train_data_iter = reader.get_data_iter(train_data, global_batch_size,
                                               config.num_steps)

        total_loss = 0
        iters = 0
        batch_cost_avg = TimeCostAverage()

        init_hidden, init_cell = generate_init_data()
        batch_start_time = time.time()
        for batch_id, batch in enumerate(train_data_iter):
            input_data_feed = prepare_input(
                batch,
                init_hidden=init_hidden,
                init_cell=init_cell,
                epoch_id=epoch_id,
                with_lr=True,
                device_count=device_count)
            fetch_outs = exe.run(train_program,
                                 feed=input_data_feed,
                                 fetch_list=[
                                     loss.name, "learning_rate",
                                     last_hidden.name, last_cell.name
                                 ],
                                 use_program_cache=True)
            batch_time = time.time() - batch_start_time
            batch_times.append(batch_time)
            batch_cost_avg.record(batch_time)

            cost_train = np.array(fetch_outs[0])
            lr = np.array(fetch_outs[1])
            # The final states of this batch become the next batch's inputs.
            init_hidden = np.array(fetch_outs[2])
            init_cell = np.array(fetch_outs[3])
            total_loss += cost_train
            iters += config.num_steps
            if batch_id > 0 and batch_id % log_interval == 0:
                ppl = np.exp(total_loss / iters)
                print(
                    "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
                    % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
                batch_cost_avg.reset()

            # profiler tools for benchmark
            if args.profile and batch_id == log_interval:
                profiler.reset_profiler()
            elif args.profile and batch_id == (log_interval + 5):
                break

            batch_start_time = time.time()

        ppl = np.exp(total_loss / iters)
        return ppl

    def train_an_epoch_dataloader(epoch_id, batch_times):
        """Run one epoch driven by ``dataloader``; return its perplexity.

        Appends step durations to ``batch_times``. The first step's time is
        not recorded separately; the time since the last step start is
        appended once the loop ends.
        """
        # get train epoch size
        log_interval = get_log_interval(len(train_data))

        init_hidden, init_cell = generate_init_data()

        total_loss = 0
        iters = 0
        batch_cost_avg = TimeCostAverage()

        dataloader.start()
        batch_id = 0
        try:
            while True:
                data_feeds = {}
                if batch_id == 0:
                    batch_time = 0
                    batch_start_time = time.time()
                else:
                    batch_time = time.time() - batch_start_time
                    batch_times.append(batch_time)
                    batch_start_time = time.time()
                    batch_cost_avg.record(batch_time)

                new_lr = generate_new_lr(epoch_id, device_count)
                data_feeds['learning_rate'] = new_lr
                data_feeds["init_hidden"] = init_hidden
                data_feeds["init_cell"] = init_cell

                fetch_outs = exe.run(train_program,
                                     feed=data_feeds,
                                     fetch_list=[
                                         loss.name, "learning_rate",
                                         last_hidden.name, last_cell.name
                                     ],
                                     use_program_cache=True)

                cost_train = np.array(fetch_outs[0])
                lr = np.array(fetch_outs[1])
                init_hidden = np.array(fetch_outs[2])
                init_cell = np.array(fetch_outs[3])

                total_loss += cost_train
                iters += config.num_steps
                # get_log_interval() never returns 0, so modulo is safe.
                if batch_id > 0 and batch_id % log_interval == 0:
                    ppl = np.exp(total_loss / iters)
                    print(
                        "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
                        % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
                    batch_cost_avg.reset()

                batch_id += 1
                # profiler tools for benchmark
                if args.profile and batch_id == log_interval:
                    profiler.reset_profiler()
                elif args.profile and batch_id == (log_interval + 5):
                    break
        except fluid.core.EOFException:
            # The generator is exhausted: end of this epoch.
            pass
        finally:
            # Reset on every exit path, including the profiler `break`, not
            # only on EOFException. This leaves the loader in the same state
            # as after a normal EOF before the next epoch calls start().
            dataloader.reset()

        batch_times.append(time.time() - batch_start_time)
        ppl = np.exp(total_loss / iters)
        return ppl

    def train():
        """Train for config.max_epoch epochs, validating and saving each one."""
        if args.use_dataloader:

            def data_gen():
                data_iter_size = config.batch_size
                train_batches = reader.get_data_iter(train_data, data_iter_size,
                                                     config.num_steps)
                for batch in train_batches:
                    x, y = batch
                    x = x.reshape((-1, config.num_steps, 1))
                    y = y.reshape((-1, 1))
                    yield x, y

            dataloader.set_batch_generator(data_gen)

        total_time = 0.0
        for epoch_id in range(config.max_epoch):
            batch_times = []
            epoch_start_time = time.time()
            if args.use_dataloader:
                train_ppl = train_an_epoch_dataloader(epoch_id, batch_times)
            else:
                train_ppl = train_an_epoch(epoch_id, batch_times)
            epoch_time = time.time() - epoch_start_time
            total_time += epoch_time
            print(
                "\nTrain epoch:[%d]; epoch Time: %.5f; ppl: %.5f; avg_time: %.5f steps/s \n"
                % (epoch_id, epoch_time, train_ppl[0],
                   len(batch_times) / sum(batch_times)))

            # FIXME(zjl): ppl[0] increases as batch_size increases.
            # We should find a better way to calculate ppl by normalizing batch_size.
            if device_count == 1 and config.batch_size <= 20 and epoch_id == 0 and train_ppl[
                    0] > 1000:
                # for bad init, after first epoch, the loss is over 1000
                # no more need to continue
                print(
                    "Parameters are randomly initialized and not good this time because the loss is over 1000 after the first epoch."
                )
                print("Abort this training process and please start again.")
                return

            if epoch_id == config.max_epoch - 1 and args.enable_ce:
                # kpis
                print("ptblm\tlstm_language_model_%s_duration_card%d\t%s" %
                      (args.rnn_model, device_count,
                       total_time / config.max_epoch))
                print("ptblm\tlstm_language_model_%s_loss_card%d\t%s" %
                      (args.rnn_model, device_count, train_ppl[0]))

            if not args.profile:
                # NOTE(zjl): sometimes we have not enough data for eval if batch_size is large, i.e., 2100
                # Just skip to avoid error. The check must use the batch size
                # that evaluate() actually iterates with.
                if has_enough_data(valid_data, global_batch_size,
                                   config.num_steps):
                    valid_ppl = evaluate(valid_data)
                    print("Valid ppl: %.5f" % valid_ppl[0])
                else:
                    print(
                        'WARNING: length of valid_data is {}, which is not enough for batch_size {} and num_steps {}'.
                        format(
                            len(valid_data), global_batch_size,
                            config.num_steps))

                save_model_dir = os.path.join(args.save_model_dir,
                                              str(epoch_id))
                if not os.path.exists(save_model_dir):
                    mkpath(save_model_dir)
                save_model_dir = os.path.join(save_model_dir, 'params')

                fluid.save(main_program, save_model_dir)
                print("Saved model to: %s.\n" % save_model_dir)

    with profile_context(args.profile, args.profiler_path):
        train()

    if has_enough_data(test_data, global_batch_size, config.num_steps):
        test_ppl = evaluate(test_data)
        print("Test ppl:", test_ppl[0])
    else:
        print(
            'WARNING: length of test_data is {}, which is not enough for batch_size {} and num_steps {}'.
            format(len(test_data), global_batch_size, config.num_steps))


if __name__ == '__main__':
    main()