program.py 22.7 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27 28 29
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
30 31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
32
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
33
from ppocr.utils.logging import get_logger
34
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
L
LDOUBLEV 已提交
35
from ppocr.utils import profiler
D
dyning 已提交
36
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
37

D
dyning 已提交
38

L
LDOUBLEV 已提交
39 40 41 42 43 44 45
class ArgsParser(ArgumentParser):
    """Command-line parser for the PaddleOCR tools.

    Options:
        -c/--config: path of the YAML config file (required).
        -o/--opt: ``key=value`` overrides; ``key`` may be dotted
            (e.g. ``Global.epoch_num=10``) and is applied by ``merge_config``.
        -p/--profiler_options: profiler option string.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format '
            '\"key1=value1;key2=value2;key3=value3\".')

    def parse_args(self, argv=None):
        """Parse ``argv`` and turn the ``--opt`` list into a dict.

        Raises:
            AssertionError: if ``--config`` was not supplied.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``['k1=v1', 'k2=v2', ...]`` into ``{'k1': v1, ...}``.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=``. Each value is YAML-decoded.

        Raises:
            ValueError: if an item contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}': expected the form key=value".format(
                        s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; this
            # is acceptable only because the input is the user's own CLI.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the file extension is not .yml or .yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Context manager so the file handle is closed as soon as it is parsed.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
84 85


86
def merge_config(config, opts):
    """
    Merge override options into the global config (in place).
    Args:
        config (dict): Global config to be updated.
        opts (dict): Overrides to apply. A plain key replaces the top-level
            entry, except that a dict value aimed at an existing dict
            section updates that section. A dotted key such as
            ``Global.epoch_num`` walks into nested dicts and sets the last
            component. ``None`` or an empty dict is a no-op.
    Returns: global config (the same object as ``config``)
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level entry of ``config``.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    if not opts:
        return config
    for key, value in opts.items():
        if "." not in key:
            # Merge into an existing dict section; otherwise (missing key,
            # None or non-dict section) just replace the entry.
            if isinstance(value, dict) and isinstance(config.get(key), dict):
                config[key].update(value)
            else:
                config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), sub_keys[0])
            cur = config[sub_keys[0]]
            # Descend through all intermediate keys, then set the leaf.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    return config
L
LDOUBLEV 已提交
113 114


X
xiaoting 已提交
115
def check_device(use_gpu, use_xpu=False):
    """
    Print an error and exit when use_gpu / use_xpu is set to true but the
    installed paddlepaddle is not compiled with CUDA / XPU support.

    Args:
        use_gpu (bool): value of ``Global.use_gpu``.
        use_xpu (bool): value of ``Global.use_xpu``.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception:
        # Best effort: SystemExit is not an Exception subclass, so the exits
        # above still propagate. Only failures of the capability probes are
        # ignored (presumably to tolerate paddle builds lacking these APIs).
        pass


139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
def check_xpu(use_xpu):
    """
    Exit the process when use_xpu=true is requested but the installed
    paddlepaddle build is not compiled with XPU support.

    Args:
        use_xpu (bool): value of ``Global.use_xpu``.

    Any non-exit error raised while probing paddle is ignored.
    """
    message = ("Config use_xpu cannot be set as true while you are "
               "using paddlepaddle cpu/gpu version ! \nPlease try: \n"
               "\t1. Install paddlepaddle-xpu to run model on XPU \n"
               "\t2. Set use_xpu as false in config file to run "
               "model on CPU/GPU")
    try:
        # Nothing to do unless XPU was requested and is unavailable.
        if not use_xpu or paddle.is_compiled_with_xpu():
            return
        print(message)
        sys.exit(1)
    except Exception:
        pass


W
WenmuZhou 已提交
158
def _train_forward(model, images, batch, model_type, extra_input):
    """Run the model on one batch using the call form its family expects.

    Table models and algorithms that need extra inputs receive the remaining
    batch items via ``data=``; KIE/VQA models receive the whole batch; all
    other models receive only the images.
    """
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", 'vqa']:
        return model(batch)
    return model(images)


def _max_iter(dataloader):
    """Return how many batches of ``dataloader`` one pass consumes.

    The last batch is skipped on Windows (behaviour kept from the original
    loop; presumably a data-loader workaround -- TODO confirm).
    """
    if platform.system() == "Windows":
        return len(dataloader) - 1
    return len(dataloader)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): global config; reads the ``Global``,
            ``Architecture``, ``Loss`` and ``profiler_options`` entries.
        train_dataloader: training loader; rebuilt at the start of an epoch
            when ``train_dataloader.dataset.need_reset`` is true.
        valid_dataloader: evaluation loader. When it is None or empty,
            evaluation during training is disabled.
        device: device passed to ``build_dataloader`` on rebuilds.
        model: network being trained (switched to train mode here).
        loss_class: callable ``(preds, batch)`` returning a dict of losses
            that contains a ``'loss'`` entry.
        optimizer: optimizer (``get_lr`` / ``step`` / ``clear_grad``).
        lr_scheduler: stepped after every iteration unless it is a float.
        post_process_class: post-processing callable used for metrics.
        eval_class: metric object; ``eval_class.main_indicator`` selects
            the best model.
        pre_best_model_dict (dict): state restored from a checkpoint; may
            contain ``global_step``, ``start_epoch`` and best metrics.
        logger: logger for progress messages.
        log_writer: optional metric/model logger.
        scaler: optional AMP loss scaler (``scale`` / ``minimize``); when
            set, the forward pass runs under ``paddle.amp.auto_cast``.
    """
    global_config = config['Global']
    cal_metric_during_train = global_config.get('cal_metric_during_train',
                                                False)
    calc_epoch_interval = global_config.get('calc_epoch_interval', 1)
    log_smooth_window = global_config['log_smooth_window']
    epoch_num = global_config['epoch_num']
    print_batch_step = global_config['print_batch_step']
    eval_batch_step = global_config['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = pre_best_model_dict.get('global_step', 0)

    # eval_batch_step is either an interval or [start_step, interval].
    start_eval_step = 0
    eval_step_is_pair = isinstance(eval_batch_step,
                                   (list, tuple)) and len(eval_batch_step) >= 2
    if eval_step_is_pair:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
    # Disable evaluation whenever there is nothing to evaluate, regardless of
    # the eval_batch_step format (eval() would otherwise fail on no data).
    if valid_dataloader is None or len(valid_dataloader) == 0:
        logger.info(
            'No Images in eval dataset, evaluation during training ' \
            'will be disabled'
        )
        start_eval_step = 1e111
    if eval_step_is_pair:
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))

    save_epoch_step = global_config['save_epoch_step']
    save_model_dir = global_config['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    arch_config = config['Architecture']
    algorithm = arch_config['algorithm']
    use_srn = algorithm == "SRN"
    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"]
    if algorithm == 'Distillation':
        # A distillation setup needs extra inputs if any sub-model does.
        extra_input = any(sub_model['algorithm'] in extra_input_models
                          for sub_model in arch_config["Models"].values())
    else:
        extra_input = algorithm in extra_input_models
    # model_type is optional in the Architecture config.
    model_type = arch_config.get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    # Accumulators for the periodic log line; reset after each log.
    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    max_iter = _max_iter(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = _max_iter(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            # forward (under AMP autocast when a scaler is supplied)
            if scaler:
                with paddle.amp.auto_cast():
                    preds = _train_forward(model, images, batch, model_type,
                                           extra_input)
            else:
                preds = _train_forward(model, images, batch, model_type,
                                       extra_input)

            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)
            steps_since_log += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                # Average over the steps actually run since the last log
                # (fewer than print_batch_step at epoch end or after resume).
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / steps_since_log,
                    train_batch_cost / steps_since_log,
                    total_samples / steps_since_log,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): ModelAverage.apply appears to be a context
                    # manager in paddle, so calling it without `with` may be a
                    # no-op; a fresh instance also has no accumulated
                    # averages -- TODO confirm intended behaviour.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return


M
refine  
MissPenguin 已提交
441 442 443 444
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    """Evaluate ``model`` on ``valid_dataloader`` and return its metrics.

    Args:
        model: network to evaluate; put in eval mode for the duration and
            switched back to train mode afterwards (also on error).
        valid_dataloader: evaluation loader yielding batches whose first
            item is the images.
        post_process_class: post-processing callable for predictions.
        eval_class: metric object called per batch; ``get_metric()``
            supplies the returned dict.
        model_type (str|None): selects the forward/metric call form
            (``'table'``, ``'kie'``, ``'vqa'`` or other).
        extra_input (bool): pass ``batch[1:]`` to the model as ``data=``.

    Returns:
        dict: ``eval_class.get_metric()`` plus ``'fps'`` (samples per
        second of measured time; 0.0 when no batch was timed).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    max_iter = len(valid_dataloader) - 1 if platform.system(
    ) == "Windows" else len(valid_dataloader)
    try:
        with paddle.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'vqa']:
                    preds = model(batch)
                else:
                    preds = model(images)
                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Timed span covers the forward pass and the numpy conversion.
                total_time += time.time() - start
                # Obtain usable results from post-processing methods and
                # evaluate the results of the current batch
                if model_type in ['kie']:
                    eval_class(preds, batch_numpy)
                elif model_type in ['table', 'vqa']:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric,eg. acc or hmean
            metric = eval_class.get_metric()
    finally:
        # Always release the progress bar and restore training mode.
        pbar.close()
        model.train()
    # Guard against division by zero when no batch was processed.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
496

T
tink2123 已提交
497

B
Bin Lu 已提交
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546
def update_center(char_center, post_result, preds):
    """Fold correctly-recognised samples into running per-class feature means.

    Args:
        char_center (dict): maps predicted class index -> ``[mean, count]``;
            updated in place and also returned.
        post_result (tuple): ``(result, label)`` from post-processing. A
            sample contributes only when ``result[i][0] == label[i][0]``
            (presumably the decoded text -- TODO confirm).
        preds (tuple): ``(feats, logits)``. ``logits`` are arg-maxed over
            the last axis to give one class per time step; ``feats[i][t]``
            is the feature used for sample ``i`` at step ``t``.

    Returns:
        dict: the updated ``char_center``.
    """
    result, label = post_result
    feats, logits = preds
    pred_ids = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    pred_ids = pred_ids.numpy()

    for sample in range(len(label)):
        if not result[sample][0] == label[sample][0]:
            continue
        sample_feats = feats[sample]
        for step, char_id in enumerate(pred_ids[sample]):
            if char_id not in char_center:
                char_center[char_id] = [sample_feats[step], 1]
                continue
            stats = char_center[char_id]
            # Incremental mean: new = (old * n + x) / (n + 1).
            stats[0] = (stats[0] * stats[1] + sample_feats[step]) / (
                stats[1] + 1)
            stats[1] += 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """Compute per-character feature centres over ``eval_dataloader``.

    Each batch is run through ``model``, post-processed, and folded into the
    running means kept by ``update_center``.

    Returns:
        dict: predicted class index -> mean feature (the running sample
        counts are dropped).
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        pbar.close()
    return {key: value[0] for key, value in char_center.items()}


547
def preprocess(is_train=False):
    """Parse CLI args, build the merged config and set up device and logging.

    Args:
        is_train (bool): when True, dump the merged config to
            ``<save_model_dir>/config.yml`` and log to
            ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, log_writer)``. ``log_writer`` is a
        ``Loggers`` wrapper over the enabled VisualDL / W&B loggers, or None
        when neither is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = config['Global'].get('use_xpu', False)
    check_xpu(use_xpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'RobustScanner'
    ]

    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if config['Global'].get('use_visualdl'):
        save_model_dir = config['Global']['save_model_dir']
        loggers.append(VDLLogger(save_model_dir))
    if config['Global'].get('use_wandb') or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        # Copy so that adding 'save_dir' does not mutate config['wandb'];
        # an empty `wandb:` section (None) is treated as no parameters.
        wandb_params = dict(config.get('wandb') or {})
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer