program.py 19.1 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
W
WenmuZhou 已提交
24 25 26 27 28 29
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
30 31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
D
dyning 已提交
32 33
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
L
LDOUBLEV 已提交
34
from ppocr.utils import profiler
D
dyning 已提交
35 36
from ppocr.data import build_dataloader
import numpy as np
L
LDOUBLEV 已提交
37

D
dyning 已提交
38

L
LDOUBLEV 已提交
39 40 41 42 43 44 45
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the training / evaluation entry points.

    Options:
      -c/--config            path of the YAML configuration file (required).
      -o/--opt               one or more ``key=value`` overrides; each value is
                             decoded with YAML.
      -p/--profiler_options  profiler option string.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` (default: the process arguments) and decode ``--opt``.

        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict mapping each
            override key to its YAML-decoded value ({} when -o is absent).

        Raises:
            AssertionError: if ``--config`` was not given.
            ValueError: if an ``--opt`` item contains no '='.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["k1=v1", "k2=v2", ...]`` into ``{"k1": v1, "k2": v2}``.

        Only the first '=' separates key from value, so a value may itself
        contain '='. Each value is decoded with YAML (so "10" becomes 10,
        "[1,2]" a list, and so on).
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # partition() splits on the first '=' only; split('=') raised on
            # values containing '=' and gave an opaque unpack error otherwise.
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}', expected the form key=value".format(
                        s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; this is
            # only acceptable because the input is the trusted command line.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """A dict whose top-level keys can also be read as attributes.

    Attribute access is only added at the outermost level: nested dict values
    are returned unchanged, not wrapped in AttrDict.
    """

    def __init__(self, **kwargs):
        # Keyword arguments become the initial entries.
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Only reached when ordinary attribute lookup fails. Raising
        # AttributeError (not KeyError) keeps hasattr()/getattr(default) sane.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]


# Process-wide configuration store. merge_config() updates it in place and
# load_config() returns it.
global_config = AttrDict()

# Baseline values merged into global_config by load_config() before the YAML
# file itself is merged.
default_config = {'Global': {'debug': False, }}

L
LDOUBLEV 已提交
89 90 91 92 93 94 95 96

def load_config(file_path):
    """
    Load config from yml/yaml file.

    ``default_config`` is merged into the global config first and the YAML
    file on top of it, so values from the file take precedence.

    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config (the module-level ``global_config`` AttrDict).
    Raises:
        AssertionError: if ``file_path`` does not end in .yml or .yaml.
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically
    # (it used to be left open).
    # NOTE: yaml.Loader can instantiate arbitrary Python objects; only load
    # config files from trusted sources.
    with open(file_path, 'rb') as f:
        loaded = yaml.load(f, Loader=yaml.Loader)
    # An empty YAML document loads as None; treat it as "nothing to merge"
    # instead of crashing in merge_config.
    merge_config(loaded or {})
    return global_config


def merge_config(config):
    """
    Merge config into global config.

    * A key without a dot replaces the top-level entry. When both the new
      value and the existing entry are dicts, the new dict is instead merged
      one level deep with ``dict.update``.
    * A dotted key such as ``"Global.epoch_num"`` walks the existing nested
      dicts and assigns only the final component.

    Values are deep-copied before being stored. Later in-place updates of the
    global config therefore never mutate the caller's dicts; previously the
    first load silently modified ``default_config``.

    Args:
        config (dict): Config to be merged.
    Returns: None; the global config is updated in place.
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of the global config.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    import copy

    for key, value in config.items():
        value = copy.deepcopy(value)
        if "." not in key:
            # Merge only dict-into-dict; anything else (including an existing
            # None or scalar entry) is replaced rather than crashing on
            # .update().
            if isinstance(value, dict) and isinstance(
                    global_config.get(key), dict):
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            # Descend to the parent of the last component, then assign there.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value


def check_gpu(use_gpu):
    """
    Print an error and exit with status 1 when ``use_gpu`` is set but the
    installed paddlepaddle was built without CUDA.

    Args:
        use_gpu (bool): value of ``Global.use_gpu`` from the config.

    The check is best effort: if asking paddle about CUDA support raises, the
    error is ignored and execution continues.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"

    if not use_gpu:
        return
    try:
        compiled_with_cuda = paddle.is_compiled_with_cuda()
    except Exception:
        # Never block start-up just because the capability query failed.
        return
    if not compiled_with_cuda:
        print(err)
        sys.exit(1)


W
WenmuZhou 已提交
150
def _train_forward(model, batch, images, model_type, extra_input):
    """Run one forward pass using the call signature the model expects."""
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type == "kie":
        return model(batch)
    return model(images)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None,
          scaler=None):
    """Run the training loop described by ``config``.

    Epochs run from ``start_epoch`` (taken from ``pre_best_model_dict``,
    default 1) to ``Global.epoch_num`` inclusive. In each step the batch is
    forwarded once, the loss is back-propagated (through ``scaler`` when AMP
    is enabled), and the optimizer and lr scheduler are stepped. Rank 0
    additionally:

    * logs statistics,
    * evaluates on ``valid_dataloader`` on the configured schedule,
    * saves the ``best_accuracy``, ``latest`` and ``iter_epoch_<n>``
      checkpoints.

    Args:
        config (dict): full configuration. The ``Global`` and
            ``Architecture`` sections and the top-level ``profiler_options``
            are read.
        train_dataloader: ignored on entry; a new train dataloader is built
            with ``seed=epoch`` at the start of every epoch.
        valid_dataloader: evaluation dataloader. Evaluation is disabled when
            it is None or empty.
        device: device handed to ``build_dataloader``.
        model: network to train; switched to training mode here.
        loss_class: callable ``(preds, batch) -> dict`` containing ``'loss'``.
        optimizer: provides ``get_lr``, ``step`` and ``clear_grad``.
        lr_scheduler: stepped every iteration unless it is a float.
        post_process_class: converts predictions for metric computation.
        eval_class: metric object exposing ``main_indicator``/``get_metric``.
        pre_best_model_dict (dict): state restored from a checkpoint. It may
            hold ``global_step``, ``start_epoch`` and previous best metrics.
        logger: logger for progress messages.
        vdl_writer: optional VisualDL writer, closed on rank 0 at the end.
        scaler: optional AMP grad scaler; mixed precision is used when set.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']

    # eval_batch_step is either N (evaluate every N steps) or [start, N]
    # (evaluate every N steps once `start` steps have passed).
    start_eval_step = 0
    eval_schedule_given = isinstance(eval_batch_step,
                                     list) and len(eval_batch_step) >= 2
    if eval_schedule_given:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
    # Disable evaluation whenever there is nothing to evaluate on. This used
    # to be checked only for the [start, N] form, so an integer schedule with
    # an empty eval set still ran eval() on it.
    if valid_dataloader is None or len(valid_dataloader) == 0:
        logger.info(
            'No Images in eval dataset, evaluation during training will be disabled'
        )
        start_eval_step = 1e111
    if eval_schedule_given:
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))

    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    # These algorithms take the remaining batch fields as model input too.
    extra_input = config['Architecture'][
        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
    model_type = config['Architecture'].get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    for epoch in range(start_epoch, epoch_num + 1):
        train_dataloader = build_dataloader(
            config, 'Train', device, logger, seed=epoch)
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()
        # On Windows the last batch is skipped.
        max_iter = len(train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            train_start = time.time()
            # Exactly one forward pass per step. Previously:
            # - an extra, untimed forward ran before this one, and
            # - the kie prediction model(batch) was overwritten by
            #   model(images).
            if scaler:
                # use amp
                with paddle.amp.auto_cast():
                    preds = _train_forward(model, batch, images, model_type,
                                           extra_input)
            else:
                preds = _train_forward(model, batch, images, model_type,
                                       extra_input)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            train_run_cost += time.time() - train_start
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs, train_reader_cost /
                    print_batch_step, (train_reader_cost + train_run_cost) /
                    print_batch_step, total_samples,
                    total_samples / (train_reader_cost + train_run_cost))
                logger.info(strs)
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): a fresh ModelAverage is built at every
                    # evaluation and apply() is called outside a `with`
                    # block; confirm whether this has any effect.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k),
                                                  cur_metric[k], global_step)
                # A tie with the current best also refreshes best_accuracy.
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            optimizer.clear_grad()
            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


M
refine  
MissPenguin 已提交
382 383 384 385
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    """Evaluate ``model`` on ``valid_dataloader`` and return its metrics.

    The model is put in eval mode and gradients are disabled for the run.
    Training mode is restored afterwards, also when an exception escapes.

    Args:
        model: network to evaluate.
        valid_dataloader: iterable of batches whose first element is the
            image tensor.
        post_process_class: converts ``(preds, batch[1])`` into results for
            ``eval_class``; not used for 'table' and 'kie' models.
        eval_class: metric accumulator exposing ``get_metric()``.
        model_type (str | None): 'table' and 'kie' select special input and
            metric handling.
        extra_input (bool): pass ``batch[1:]`` to the model as ``data``.

    Returns:
        dict: metrics from ``eval_class.get_metric()`` plus ``'fps'``. The
        ``'fps'`` value is images per second of timed forward + conversion
        work, or 0.0 when nothing was timed.
    """
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        try:
            # On Windows the last batch is skipped.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                # Exactly one forward pass. With two independent `if`s, the
                # trailing `else` overwrote table/extra-input predictions
                # with model(images).
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type == "kie":
                    preds = model(batch)
                else:
                    preds = model(images)
                batch = [item.numpy() for item in batch]
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    # Obtain usable results from post-processing methods
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
        finally:
            pbar.close()
            model.train()
    # Guard against ZeroDivisionError when no batch was processed.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
429

T
tink2123 已提交
430

B
Bin Lu 已提交
431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
def update_center(char_center, post_result, preds):
    """Fold correctly recognised samples into per-class feature centers.

    Args:
        char_center (dict): maps a predicted class index to
            ``[mean_feature, count]``; updated in place.
        post_result: pair ``(result, label)``. A sample contributes only when
            ``result[i][0] == label[i][0]``.
        preds: pair ``(feats, logits)`` of tensors. The arg-max of ``logits``
            over its last axis gives the class index of each time step.

    Returns:
        dict: the same ``char_center`` object.
    """
    result, label = post_result
    feats, logits = preds
    feats = feats.numpy()
    indices = paddle.argmax(logits, axis=-1).numpy()

    for sample_id, target in enumerate(label):
        if result[sample_id][0] != target[0]:
            continue
        sample_feat = feats[sample_id]
        for step, cls_id in enumerate(indices[sample_id]):
            entry = char_center.get(cls_id)
            if entry is None:
                char_center[cls_id] = [sample_feat[step], 1]
            else:
                # Running mean: new = (old * n + x) / (n + 1)
                mean, count = entry
                entry[0] = (mean * count + sample_feat[step]) / (count + 1)
                entry[1] = count + 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """Compute the mean feature vector of every predicted character class.

    Runs ``model`` over ``eval_dataloader`` (skipping the last batch on
    Windows) and accumulates the features of correctly recognised samples
    with ``update_center``.

    Args:
        model: recognition model whose output is a ``(feats, logits)`` pair.
        eval_dataloader: iterable of batches; ``batch[0]`` holds the images
            and ``batch[1]`` is passed to ``post_process_class``.
        post_process_class: callable ``(preds, batch[1]) -> (result, label)``.

    Returns:
        dict: class index -> mean feature (the running count is dropped).
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        pbar.close()
    # Keep only the mean feature of each [mean, count] entry.
    return {key: center[0] for key, center in char_center.items()}


480
def preprocess(is_train=False):
    """Parse the command line and prepare what the entry scripts need.

    Steps:

    * Load the YAML config given by ``-c``, then merge the ``-o`` overrides
      and the profiler options into it.
    * Validate the algorithm and select the device.
    * Build the logger, plus a VisualDL writer when ``Global.use_visualdl``
      is set.

    When ``is_train`` is True, the merged config is written to
    ``<save_model_dir>/config.yml`` and logging also goes to
    ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, vdl_writer)``. ``vdl_writer`` is
        None when VisualDL is disabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    merge_config(profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    supported_algorithms = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR'
    ]
    assert alg in supported_algorithms, \
        "algorithm {} is not supported, expected one of {}".format(
            alg, supported_algorithms)

    windows_not_support_list = ['PSE']
    if platform.system() == "Windows" and alg in windows_not_support_list:
        logger.warning('{} is not support in Windows now'.format(
            windows_not_support_list))
        sys.exit()

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    # A config without the use_visualdl key means "VisualDL off" instead of
    # a KeyError.
    if config['Global'].get('use_visualdl', False):
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer