program.py 20.9 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27 28 29
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
30 31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
32
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
33
from ppocr.utils.logging import get_logger
34
from ppocr.utils.loggers import VDLLogger, WandbLogger
L
LDOUBLEV 已提交
35
from ppocr.utils import profiler
D
dyning 已提交
36
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
37

D
dyning 已提交
38

L
LDOUBLEV 已提交
39 40 41 42 43 44 45
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the tools built on this module.

    Options:
        -c/--config: path of the YAML config file (required by ``parse_args``).
        -o/--opt: one or more ``key=value`` overrides merged into the config.
        -p/--profiler_options: raw profiler option string.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format '
            '\"key1=value1;key2=value2;key3=value3\".')

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert the ``-o`` overrides into a dict.

        Args:
            argv (list[str] | None): arguments to parse; ``None`` lets
                argparse read ``sys.argv[1:]``.
        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict of
            overrides (empty when ``-o`` was not given).
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``["a.b=1", "c=[1, 2]"]`` into ``{"a.b": 1, "c": [1, 2]}``.

        Only the first ``=`` separates key from value, so values may contain
        ``=`` themselves. Values are parsed as YAML so numbers, booleans and
        lists get their natural Python types.

        Raises:
            ValueError: if an option has no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            key, sep, value = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}', expected the form key=value".format(
                        s))
            # Values come from the local command line; yaml.Loader can build
            # arbitrary Python objects, so never feed it untrusted input.
            config[key] = yaml.load(value, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded; the extension
            must be ``.yml`` or ``.yaml``.
    Returns: global config (the parsed YAML document)
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Context manager closes the handle deterministically. yaml.Loader can
    # construct arbitrary objects: only load trusted config files.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
84 85


86
def merge_config(config, opts):
    """
    Merge override options into the global config (in place).

    Args:
        config (dict): Global config to update; it is modified in place.
        opts (dict): Overrides to apply. A key without ``.`` replaces the
            top-level entry, or, when both the existing entry and the new
            value are dicts, is merged into it with ``dict.update``. A dotted
            key such as ``Global.epoch_num`` walks into nested dicts and sets
            its last component.
    Returns: global config (the same object as ``config``)
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of ``config``.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    for key, value in opts.items():
        if "." not in key:
            # Only merge when both sides are dicts; otherwise `.update` would
            # fail on a non-dict existing entry, so replace it instead.
            if isinstance(value, dict) and isinstance(config.get(key), dict):
                config[key].update(value)
            else:
                config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), sub_keys[0])
            cur = config[sub_keys[0]]
            # Walk down to the parent of the final component, then assign.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    return config
L
LDOUBLEV 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126


def check_gpu(use_gpu):
    """
    Print an error and exit with status 1 when use_gpu is requested but the
    installed paddlepaddle was built without CUDA.

    Failures of the capability query itself are ignored (best effort).
    """
    err = ("Config use_gpu cannot be set as true while you are "
           "using paddlepaddle cpu version ! \nPlease try: \n"
           "\t1. Install paddlepaddle-gpu to run model on GPU \n"
           "\t2. Set use_gpu as false in config file to run "
           "model on CPU")

    try:
        # Nothing to check when the GPU is not requested or CUDA is present.
        if not use_gpu or paddle.is_compiled_with_cuda():
            return
        print(err)
        # SystemExit is not an Exception subclass, so it escapes the handler.
        sys.exit(1)
    except Exception:
        pass


134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
def check_xpu(use_xpu):
    """
    Print an error and exit with status 1 when use_xpu is requested but the
    installed paddlepaddle was built without XPU support.

    Failures of the capability query itself are ignored (best effort).
    """
    err = ("Config use_xpu cannot be set as true while you are "
           "using paddlepaddle cpu/gpu version ! \nPlease try: \n"
           "\t1. Install paddlepaddle-xpu to run model on XPU \n"
           "\t2. Set use_xpu as false in config file to run "
           "model on CPU/GPU")

    try:
        # Nothing to check when XPU is not requested or support is present.
        if not use_xpu or paddle.is_compiled_with_xpu():
            return
        print(err)
        # SystemExit is not an Exception subclass, so it escapes the handler.
        sys.exit(1)
    except Exception:
        pass


W
WenmuZhou 已提交
153
def _train_forward(model, images, batch, model_type, extra_input):
    """Run ``model`` on one batch, choosing the call form by model family.

    Args:
        model: network being trained.
        images: ``batch[0]``, the primary model input.
        batch: full batch list produced by the dataloader.
        model_type (str | None): ``Architecture.model_type`` from the config.
        extra_input (bool): pass ``batch[1:]`` as ``data=`` to the model.
    Returns:
        The model predictions.
    """
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", 'vqa']:
        return model(batch)
    return model(images)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): global config; reads ``Global``, ``Architecture`` and
            the top-level ``profiler_options`` entry.
        train_dataloader: training loader; rebuilt each epoch when its
            dataset sets ``need_reset``.
        valid_dataloader: loader handed to ``eval``.
        device: device passed to ``build_dataloader`` when rebuilding.
        model: network to train.
        loss_class: callable ``(preds, batch)`` returning a dict with a
            ``'loss'`` entry.
        optimizer: provides ``get_lr``, ``step`` and ``clear_grad``.
        lr_scheduler: stepped once per batch unless it is a float.
        post_process_class: converts predictions for metric computation.
        eval_class: metric object with ``main_indicator`` and ``get_metric``.
        pre_best_model_dict (dict): state restored from a checkpoint (may
            contain ``global_step``, ``start_epoch`` and best metrics).
        logger: logger for progress messages.
        log_writer: optional metrics logger; written only on rank 0.
        scaler: optional AMP scaler; when set, forward passes run under
            ``paddle.amp.auto_cast`` and the loss is scaled.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = pre_best_model_dict.get('global_step', 0)

    # eval_batch_step may be [start_step, interval]; an int means "every N".
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training '
                'will be disabled')
            # Effectively unreachable step -> evaluation never runs.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, "
            "an evaluation is run every {} iterations".format(
                start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    extra_input = config['Architecture'][
        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
    model_type = config['Architecture'].get('model_type', None)

    start_epoch = best_model_dict.get('start_epoch', 1)

    # Accumulators for the periodic speed log; reset after each log line.
    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch is skipped (kept from the original behavior).
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            # use amp; both paths share the same model-family dispatch
            if scaler:
                with paddle.amp.auto_cast():
                    preds = _train_forward(model, images, batch, model_type,
                                           extra_input)
            else:
                preds = _train_forward(model, images, batch, model_type,
                                       extra_input)

            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                # Convert tensors only; non-tensor items pass through.
                batch = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            steps_since_log += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) *
                           len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                # Average over the steps actually accumulated since the last
                # log line: epoch-end logs can fire before print_batch_step
                # steps have elapsed.
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / steps_since_log,
                    train_batch_cost / steps_since_log,
                    total_samples / steps_since_log,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): Paddle documents ModelAverage.apply() as a
                    # context manager (`with model_average.apply(): ...`), and
                    # this freshly built instance has accumulated nothing, so
                    # this call appears to have no effect — confirm intent.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(metrics={
                        "best_{}".format(main_indicator):
                        best_model_dict[main_indicator]
                    }, prefix="EVAL", step=global_step)

                    if isinstance(log_writer, WandbLogger):
                        log_writer.log_model(
                            is_best=True,
                            prefix="best_accuracy",
                            metadata=best_model_dict)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if isinstance(log_writer, WandbLogger):
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if isinstance(log_writer, WandbLogger):
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return


M
refine  
MissPenguin 已提交
411 412 413 414
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    """Evaluate ``model`` on ``valid_dataloader`` and return its metrics.

    The model is put in eval mode for the duration of the call and switched
    back to train mode afterwards, even if evaluation raises.

    Args:
        model: network to evaluate.
        valid_dataloader: iterable of batches; ``batch[0]`` is the image input.
        post_process_class: converts predictions for the metric (not used for
            ``table``/``kie`` models).
        eval_class: metric accumulator called per batch; ``get_metric()``
            returns the final metric dict.
        model_type (str | None): ``Architecture.model_type``; selects how the
            model is called and how its output is scored.
        extra_input (bool): pass ``batch[1:]`` as ``data=`` to the model.
    Returns:
        dict: metrics from ``eval_class.get_metric()`` plus ``fps``, samples
        per second over the timed section (0.0 when nothing was timed).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    try:
        with paddle.no_grad():
            # On Windows the last batch is skipped, as in the training loop.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'vqa']:
                    preds = model(batch)
                else:
                    preds = model(images)

                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Timed section: forward pass plus tensor-to-numpy copy;
                # post-processing and metric updates are excluded.
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch_numpy)
                elif model_type in ['vqa']:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()

    # Guard: an empty loader (or a single batch skipped on Windows) would
    # otherwise divide by zero.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
467

T
tink2123 已提交
468

B
Bin Lu 已提交
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
def update_center(char_center, post_result, preds):
    """Fold features of correctly recognised samples into running centers.

    Args:
        char_center (dict): maps a predicted class index to
            ``[mean_feature, count]``; updated in place.
        post_result: pair ``(result, label)``. A sample contributes only when
            ``result[i][0] == label[i][0]``.
        preds: pair ``(feats, logits)``. The arg-max of ``logits`` over the
            last axis gives one class index per time step.
    Returns:
        dict: the same ``char_center`` object.
    """
    result, label = post_result
    feats, logits = preds
    pred_ids = paddle.argmax(logits, axis=-1)
    feat_arr, pred_arr = feats.numpy(), pred_ids.numpy()

    for sample_idx, sample_label in enumerate(label):
        matched = result[sample_idx][0] == sample_label[0]
        if not matched:
            continue
        sample_feat = feat_arr[sample_idx]
        for step, class_id in enumerate(pred_arr[sample_idx]):
            entry = char_center.get(class_id)
            if entry is None:
                char_center[class_id] = [sample_feat[step], 1]
                continue
            mean, count = entry
            # Incremental mean: (mean * n + x) / (n + 1).
            entry[0] = (mean * count + sample_feat[step]) / (count + 1)
            entry[1] = count + 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """Compute the mean feature per predicted class over a dataloader.

    Each batch is run through ``model``. ``update_center`` accumulates the
    features of correctly recognised samples. The running counts are dropped
    at the end.

    Args:
        model: network whose output is a ``(feats, logits)`` pair, as consumed
            by ``update_center``.
        eval_dataloader: iterable of batches; ``batch[0]`` is the input and
            ``batch[1]`` is passed to ``post_process_class``.
        post_process_class: produces the ``(result, label)`` pair.
    Returns:
        dict: class index -> mean feature.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the last batch is skipped, matching the other loops here.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        # Predictions are only read back as values, so no autograd graph is
        # needed; this avoids holding activations for backward.
        with paddle.no_grad():
            for idx, batch in enumerate(eval_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                preds = model(images)

                batch = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])

                # update char_center
                char_center = update_center(char_center, post_result, preds)
                pbar.update(1)
    finally:
        pbar.close()
    # Keep only the mean feature, discarding the sample counts.
    return {key: center[0] for key, center in char_center.items()}


518
def preprocess(is_train=False):
    """Parse CLI args, build the merged config and set up the runtime.

    Args:
        is_train (bool): when True, write the merged config to
            ``<save_model_dir>/config.yml`` and log to
            ``<save_model_dir>/train.log``.
    Returns:
        tuple: ``(config, device, logger, log_writer)``. ``log_writer`` is a
        ``VDLLogger``, a ``WandbLogger`` or ``None``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = config['Global'].get('use_xpu', False)
    check_xpu(use_xpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE'
    ]

    # XPU takes precedence over GPU when both are enabled.
    device = 'cpu'
    if use_gpu:
        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
    if use_xpu:
        device = 'xpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    if config['Global'].get('use_visualdl') and dist.get_rank() == 0:
        save_model_dir = config['Global']['save_model_dir']
        # VisualDL logs go to a dedicated 'vdl/' sub-directory.
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        log_writer = VDLLogger(vdl_writer_path)
    elif config['Global'].get('use_wandb') or "wandb" in config:
        # Use a local name here: save_model_dir is only bound when is_train
        # is True, so relying on it raised NameError for non-training runs.
        save_dir = config['Global']['save_model_dir']
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
        log_writer = WandbLogger(**wandb_params, config=config)
    else:
        log_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer