program.py 19.4 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27 28 29
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
30 31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
32
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
33
from ppocr.utils.logging import get_logger
L
LDOUBLEV 已提交
34
from ppocr.utils import profiler
D
dyning 已提交
35
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
36

D
dyning 已提交
37

L
LDOUBLEV 已提交
38 39 40 41 42 43 44
class ArgsParser(ArgumentParser):
    """Command-line parser for the training / evaluation entry points.

    Options:
      -c/--config            path of the YAML configuration file (required).
      -o/--opt               one or more ``key=value`` overrides; each value
                             is parsed with ``yaml.load``.
      -p/--profiler_options  profiler option string, kept verbatim.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None).

        Returns:
            argparse.Namespace whose ``opt`` attribute has been converted from
            a list of ``key=value`` strings into a dict (empty when no
            ``--opt`` was given).

        Raises:
            AssertionError: if ``--config`` was not supplied.
            ValueError: if an ``--opt`` entry contains no ``=``.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``['a.b=1', 'c=x']`` into ``{'a.b': <yaml of '1'>, ...}``.

        Only the first ``=`` separates key from value, so a value may itself
        contain ``=`` (e.g. a file path or an expression).
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}', expected the form key=value".format(
                        s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; this
            # is tolerable only because the values come from the user's own
            # command line.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config (whatever the YAML document deserializes to)
    Raises:
        AssertionError: if the file extension is not '.yml' or '.yaml'.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # The context manager guarantees the handle is closed, even if parsing
    # raises. NOTE: yaml.Loader can build arbitrary Python objects, so only
    # load trusted configuration files.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
83 84


85
def merge_config(config, opts):
    """
    Merge override options into the global config (in place).
    Args:
        config (dict): Global config; modified in place.
        opts (dict or None): Overrides to apply.
            - A key without '.' sets ``config[key] = value``; if the value is
              a dict and the key already exists, ``config[key].update(value)``
              is used instead.
            - A dotted key such as 'Global.epoch_num' walks the nested
              containers of ``config`` and assigns the last component.
    Returns: global config (the same object as ``config``)
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of ``config``.
        KeyError: presumably, if an intermediate component of a dotted key
            does not exist (raised by the nested lookup).
    """
    if not opts:
        return config
    for key, value in opts.items():
        if "." not in key:
            if isinstance(value, dict) and key in config:
                config[key].update(value)
            else:
                config[key] = value
            continue

        sub_keys = key.split('.')
        assert sub_keys[0] in config, \
            "the sub_keys can only be one of global_config: {}, but get: " \
            "{}, please check your running command".format(
                config.keys(), sub_keys[0])
        cur = config[sub_keys[0]]
        # Descend to the parent of the leaf, then assign the leaf itself.
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value
    return config
L
LDOUBLEV 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125


def check_gpu(use_gpu):
    """
    Print an error and exit with status 1 when use_gpu is true but the
    installed paddlepaddle was not compiled with CUDA.

    If the CUDA capability cannot be queried at all, the check is skipped
    (best effort) instead of aborting.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"

    if not use_gpu:
        return
    try:
        compiled_with_cuda = paddle.is_compiled_with_cuda()
    except Exception:
        # Deliberate best effort: an unqueryable build is not treated as fatal.
        return
    if not compiled_with_cuda:
        print(err)
        sys.exit(1)


W
WenmuZhou 已提交
133
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None,
          scaler=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): global config; reads the 'Global' and 'Architecture'
            sections and the top-level 'profiler_options' key.
        train_dataloader: training loader; rebuilt each epoch via
            ``build_dataloader`` when ``train_dataloader.dataset.need_reset``.
        valid_dataloader: evaluation loader.
        device: forwarded to ``build_dataloader`` when the loader is rebuilt.
        model: network being trained.
        loss_class: callable ``(preds, batch)`` returning a dict with a
            'loss' entry (every entry is logged).
        optimizer: optimizer; ``get_lr``/``step``/``clear_grad`` are used.
        lr_scheduler: stepped once per iteration unless it is a float.
        post_process_class, eval_class: post-processing and metric callables
            used during evaluation (and for train-time metrics when
            Global.cal_metric_during_train is true).
        pre_best_model_dict (dict): state restored from a checkpoint; may
            hold 'global_step', 'start_epoch' and previous best metrics.
        logger: logger for progress messages.
        vdl_writer: optional VisualDL writer, used on rank 0 only.
        scaler: optional AMP scaler; when truthy the forward pass runs under
            ``paddle.amp.auto_cast`` and the loss is scaled.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    # eval_batch_step may be given as [start_step, interval].
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # An unreachable start step effectively disables evaluation.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model.train()

    algorithm = config['Architecture']['algorithm']
    # SRN evaluates with ModelAverage applied (see the eval section below).
    model_average = algorithm == "SRN"
    extra_input = algorithm in ["SRN", "NRTR", "SAR", "SEED"]
    model_type = config['Architecture'].get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch of every epoch is skipped.
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]

            # use amp: both branches dispatch on model type identically.
            if scaler:
                with paddle.amp.auto_cast():
                    preds = _train_forward(model, images, batch, model_type,
                                           extra_input)
            else:
                preds = _train_forward(model, images, batch, model_type,
                                       extra_input)

            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'samples/s: {:.5f}, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / print_batch_step,
                    train_batch_cost / print_batch_step,
                    total_samples / print_batch_step,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): in recent Paddle releases
                    # ModelAverage.apply() is a context manager, so calling it
                    # without `with` may have no effect - confirm.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k),
                                                  cur_metric[k], global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


def _train_forward(model, images, batch, model_type, extra_input):
    """Call `model` with the input form its type expects (one training step)."""
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", 'vqa']:
        return model(batch)
    return model(images)


M
refine  
MissPenguin 已提交
383 384 385 386
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    """Evaluate `model` on `valid_dataloader` and return the metric dict.

    Args:
        model: network to evaluate; put in eval mode for the pass and back in
            train mode afterwards, also when evaluation raises.
        valid_dataloader: iterable of batches; ``batch[0]`` holds the images.
        post_process_class: callable turning predictions into results.
        eval_class: metric accumulator fed once per batch; the result of its
            ``get_metric()`` is returned.
        model_type (str, optional): selects how the model is called and how
            the metric is fed ('table', 'kie', 'vqa' or anything else).
        extra_input (bool): pass ``batch[1:]`` to the model as ``data``.
    Returns:
        dict: ``eval_class.get_metric()`` plus 'fps' (images divided by the
        measured time; 0.0 when no batch was evaluated).
    """
    model.eval()
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    total_frame = 0.0
    total_time = 0.0
    try:
        with paddle.no_grad():
            # On Windows the last batch is skipped.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'vqa']:
                    preds = model(batch)
                else:
                    preds = model(images)

                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Timed span: forward pass plus tensor-to-numpy conversion.
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch_numpy)
                elif model_type in ['vqa']:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()
    # Guard against ZeroDivisionError when no batch was evaluated.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
439

T
tink2123 已提交
440

B
Bin Lu 已提交
441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489
def update_center(char_center, post_result, preds):
    """Fold features of correctly recognised samples into per-class centers.

    Args:
        char_center (dict): maps a predicted class index to
            ``[mean_feature, count]``; updated in place.
        post_result: ``(result, label)`` pair; sample ``i`` contributes only
            when ``result[i][0] == label[i][0]``.
        preds: ``(feats, logits)``; logits are reduced with argmax over the
            last axis to get one class index per time step.
    Returns:
        dict: the same ``char_center`` object.
    """
    results, labels = post_result
    features, scores = preds
    classes = paddle.argmax(scores, axis=-1)
    features = features.numpy()
    classes = classes.numpy()

    for sample_idx, label in enumerate(labels):
        if results[sample_idx][0] != label[0]:
            continue
        sample_feats = features[sample_idx]
        for step, cls_idx in enumerate(classes[sample_idx]):
            entry = char_center.get(cls_idx)
            if entry is None:
                char_center[cls_idx] = [sample_feats[step], 1]
                continue
            center, count = entry
            # Incremental running mean over all features seen for this class.
            entry[0] = (center * count + sample_feats[step]) / (count + 1)
            entry[1] = count + 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """Compute a mean feature per predicted class over `eval_dataloader`.

    Each batch is run through `model`, post-processed, and passed to
    ``update_center``. The model's train/eval mode is not changed here.

    Returns:
        dict: class index -> mean feature.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the last batch is skipped.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        pbar.close()
    # Drop the per-class sample counts, keeping only the mean feature.
    return {key: center for key, (center, _count) in char_center.items()}


490
def preprocess(is_train=False):
    """Parse CLI arguments and prepare config, logger, device and writer.

    Loads the YAML given by --config, applies the --opt overrides, and stores
    --profiler_options under ``config['profiler_options']``.

    When `is_train` is true, the merged config is written to
    ``<save_model_dir>/config.yml`` and logging goes to
    ``<save_model_dir>/train.log``.

    It then validates ``use_gpu`` and the algorithm name, and selects the
    device. On rank 0, when ``Global.use_visualdl`` is true, it also creates
    a VisualDL writer under ``<save_model_dir>/vdl/``.

    Returns:
        tuple: (config, device, logger, vdl_writer); vdl_writer is None
        unless VisualDL logging is enabled on rank 0.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
    ], "Unsupported algorithm: {}".format(alg)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    # A missing use_visualdl entry is treated as disabled.
    if config['Global'].get('use_visualdl', False) and dist.get_rank() == 0:
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer