program.py 25.1 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27
import paddle
import paddle.distributed as dist
from tqdm import tqdm
X
xiaoting 已提交
28 29
import cv2
import numpy as np
W
WenmuZhou 已提交
30 31
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
32 33
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
34
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
35
from ppocr.utils.logging import get_logger
36
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
L
LDOUBLEV 已提交
37
from ppocr.utils import profiler
D
dyning 已提交
38
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
39

D
dyning 已提交
40

L
LDOUBLEV 已提交
41 42 43 44 45 46 47
class ArgsParser(ArgumentParser):
    """
    Command line parser shared by the training / evaluation entry points.

    Options:
        -c / --config: configuration file to use (required).
        -o / --opt: one or more ``key=value`` overrides; each value is
            parsed as YAML.
        -p / --profiler_options: profiler option string in the format
            ``key1=value1;key2=value2;key3=value3``.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """
        Parse the command line and convert ``--opt`` entries to a dict.
        Args:
            argv (list|None): Arguments to parse; ``None`` lets argparse
                read ``sys.argv``.
        Returns: the argparse namespace, with ``opt`` replaced by a dict.
        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """
        Turn a list of ``key=value`` strings into a dict.
        Args:
            opts (list|None): Raw ``--opt`` entries.
        Returns: dict mapping each key to its YAML-parsed value.
        Raises:
            ValueError: if an entry does not contain ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so that values which themselves
            # contain '=' (paths, URLs, expressions) are kept intact.
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}': expected the form key=value.".format(
                        s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; the
            # values come from the command line and are assumed trusted.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the file extension is not ``.yml`` / ``.yaml``.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically.
    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # configuration files from trusted sources.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
86 87


88
def merge_config(config, opts):
    """
    Merge override options into the global config, in place.
    Args:
        config (dict): Global config to be updated.
        opts (dict): Overrides. A key without "." replaces the top-level
            entry (or updates it when both the value and the existing
            entry's key are present and the value is a dict). A dotted key
            such as "a.b.c" assigns to the nested entry config["a"]["b"]["c"].
    Returns: global config (the same object as ``config``)
    """
    for key, value in opts.items():
        if "." in key:
            parts = key.split('.')
            root = parts[0]
            assert (
                root in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), root)
            # Walk down to the parent of the final key, then assign.
            node = config[root]
            for name in parts[1:-1]:
                node = node[name]
            node[parts[-1]] = value
        elif isinstance(value, dict) and key in config:
            config[key].update(value)
        else:
            config[key] = value
    return config
L
LDOUBLEV 已提交
115 116


X
xiaoting 已提交
117
def check_device(use_gpu, use_xpu=False):
    """
    Log error and exit when use_gpu / use_xpu is set to true but the
    installed paddlepaddle is not compiled with the matching backend.
    Args:
        use_gpu (bool): Whether the config requests GPU.
        use_xpu (bool): Whether the config requests XPU.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            # Only a warning: execution continues with the checks below.
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception:
        # Best-effort check: a failure of the capability query itself is
        # ignored. sys.exit raises SystemExit, which is not an Exception
        # subclass, so the exits above are not swallowed here.
        pass


141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
def check_xpu(use_xpu):
    """
    Print an error and exit when use_xpu is true but the installed
    paddlepaddle reports that it is not compiled with XPU support.
    Args:
        use_xpu (bool): Whether the config requests XPU.
    """
    message = ("Config use_xpu cannot be set as true while you are "
               "using paddlepaddle cpu/gpu version ! \nPlease try: \n"
               "\t1. Install paddlepaddle-xpu to run model on XPU \n"
               "\t2. Set use_xpu as false in config file to run "
               "model on CPU/GPU")

    # Nothing to verify unless XPU was requested.
    if not use_xpu:
        return

    try:
        xpu_available = paddle.is_compiled_with_xpu()
        if not xpu_available:
            print(message)
            sys.exit(1)
    except Exception as e:
        # Failures of the capability query itself are ignored.
        pass

文幕地方's avatar
文幕地方 已提交
159

文幕地方's avatar
文幕地方 已提交
160 161 162 163 164
def to_float32(preds):
    """
    Cast every paddle.Tensor found in ``preds`` to float32.

    Dicts and lists are traversed recursively and updated in place; any
    other non-tensor value is returned unchanged.
    Args:
        preds: A paddle.Tensor, or a (possibly nested) dict / list.
    Returns: the cast tensor, or the same container object that was passed.
    """
    if isinstance(preds, dict):
        # Only values are reassigned, so iterating the dict is safe.
        for name, value in preds.items():
            preds[name] = to_float32(value)
        return preds
    if isinstance(preds, list):
        for position, value in enumerate(preds):
            preds[position] = to_float32(value)
        return preds
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    return preds
178

文幕地方's avatar
文幕地方 已提交
179

W
WenmuZhou 已提交
180
def _forward_train_batch(model, batch, model_type, extra_input):
    """Run one forward pass, passing the model the inputs its type expects."""
    if model_type == 'table' or extra_input:
        # The remaining batch items are handed to the model as ``data``.
        return model(batch[0], data=batch[1:])
    if model_type in ["kie", 'sr']:
        # These model types receive the whole batch.
        return model(batch)
    return model(batch[0])


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2'):
    """
    Run the training loop with periodic logging, evaluation and checkpoints.
    Args:
        config (dict): Global config; the 'Global', 'Architecture', 'Loss'
            and 'profiler_options' entries are read here.
        train_dataloader: Iterable of training batches; rebuilt each epoch
            through build_dataloader when its dataset sets ``need_reset``.
        valid_dataloader: Dataloader handed to ``eval``.
        device: Forwarded to build_dataloader when the loader is rebuilt.
        model: Model being trained.
        loss_class: Callable ``(preds, batch)`` returning a dict that
            contains the key 'loss'.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Stepped once per batch unless it is a float.
        post_process_class: Post-processing callable used for metrics.
        eval_class: Metric object exposing ``main_indicator`` and
            ``get_metric()``.
        pre_best_model_dict (dict): Previously saved state; may contain
            'global_step' and 'start_epoch'.
        logger: Logger used for progress messages.
        log_writer: Optional metrics writer (rank 0 only).
        scaler: Optional AMP loss scaler; when truthy, AMP is used.
        amp_level (str): Level passed to ``paddle.amp.auto_cast``.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    # eval_batch_step may be given as [start_step, interval].
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # A start step that global_step never exceeds disables eval.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    # Parameter averaging is applied before evaluation only for SRN.
    model_average = use_srn
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner"
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        # A distillation model needs extra inputs if any sub-model does.
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    # model_type is optional in the config; None selects the default paths.
    model_type = config['Architecture'].get('model_type')

    start_epoch = best_model_dict[
        'start_epoch'] if 'start_epoch' in best_model_dict else 1

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch of each epoch is skipped.
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            # use amp
            if scaler:
                custom_black_list = config['Global'].get(
                    'amp_custom_black_list', [])
                with paddle.amp.auto_cast(
                        level=amp_level, custom_black_list=custom_black_list):
                    # Same input dispatch as the non-AMP path below, so
                    # 'sr' models also receive the whole batch under AMP.
                    preds = _forward_train_batch(model, batch, model_type,
                                                 extra_input)
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                preds = _forward_train_batch(model, batch, model_type,
                                             extra_input)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            # Measured from the end of the previous iteration, so it
            # includes the data loading time of this batch.
            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                # Remaining iterations times the average batch time.
                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                    '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                    'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / print_batch_step,
                    train_batch_cost / print_batch_step,
                    total_samples / print_batch_step,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                # Restart the accumulators for the next logging window.
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()
        # Save the latest weights at the end of every epoch.
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        # Additionally keep a numbered snapshot every save_epoch_step epochs.
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return


M
refine  
MissPenguin 已提交
474 475 476 477
def _forward_eval_batch(model, batch, model_type, extra_input):
    """Run one forward pass, passing the model the inputs its type expects."""
    if model_type == 'table' or extra_input:
        # The remaining batch items are handed to the model as ``data``.
        return model(batch[0], data=batch[1:])
    if model_type in ["kie", 'sr']:
        # These model types receive the whole batch.
        return model(batch)
    return model(batch[0])


def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.
    Args:
        model: Model to evaluate; switched to eval mode for the duration of
            the call and back to train mode afterwards.
        valid_dataloader: Iterable of validation batches.
        post_process_class: Post-processing callable; may be None for the
            'table' / 'kie' model types.
        eval_class: Metric object called per batch and exposing
            ``get_metric()``.
        model_type (str|None): Selects the forward / metric code path.
        extra_input (bool): Whether the model takes extra batch items.
        scaler: When truthy, the forward pass runs under AMP level 'O2'.
    Returns: metric dict from ``eval_class.get_metric()`` with an added
        'fps' entry (0.0 when no batch was evaluated).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    try:
        with paddle.no_grad():
            # On Windows the last batch is skipped.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()

                # use amp
                if scaler:
                    with paddle.amp.auto_cast(level='O2'):
                        preds = _forward_eval_batch(model, batch, model_type,
                                                    extra_input)
                    preds = to_float32(preds)
                else:
                    preds = _forward_eval_batch(model, batch, model_type,
                                                extra_input)

                batch_numpy = []
                for item in batch:
                    if isinstance(item, paddle.Tensor):
                        batch_numpy.append(item.numpy())
                    else:
                        batch_numpy.append(item)
                # Timing covers the forward pass and the numpy conversion,
                # not the post-processing / metric update below.
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    if post_process_class is None:
                        eval_class(preds, batch_numpy)
                    else:
                        post_result = post_process_class(preds, batch_numpy)
                        eval_class(post_result, batch_numpy)
                elif model_type in ['sr']:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric,eg. acc or hmean
            metric = eval_class.get_metric()
    finally:
        # Always release the progress bar and restore training mode, even
        # if evaluation fails part-way through.
        pbar.close()
        model.train()
    # Guard against an empty dataloader, where no time was accumulated.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
556

T
tink2123 已提交
557

B
Bin Lu 已提交
558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
def update_center(char_center, post_result, preds):
    """
    Update the running mean feature of each predicted index, in place.

    Only samples whose decoded result equals the label contribute.
    Args:
        char_center (dict): Maps a predicted index to ``[mean_feature, count]``.
        post_result: Pair ``(result, label)``; element ``[i][0]`` of each is
            compared for sample ``i``.
        preds: Pair ``(feats, logits)`` of tensors.
    Returns: the updated ``char_center`` dict.
    """
    result, label = post_result
    feats, logits = preds
    pred_indices = paddle.argmax(logits, axis=-1).numpy()
    feat_array = feats.numpy()

    for sample_id, target in enumerate(label):
        # Skip samples that were decoded incorrectly.
        if result[sample_id][0] != target[0]:
            continue
        sample_feat = feat_array[sample_id]
        for step, index in enumerate(pred_indices[sample_id]):
            if index not in char_center.keys():
                char_center[index] = [sample_feat[step], 1]
                continue
            entry = char_center[index]
            # Incremental mean: fold the new feature into the running average.
            entry[0] = (entry[0] * entry[1] + sample_feat[step]) / (entry[1] + 1)
            entry[1] += 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """
    Compute the mean feature per predicted index over ``eval_dataloader``.
    Args:
        model: Callable returning the ``preds`` consumed by update_center.
        eval_dataloader: Iterable of batches; item 0 is fed to the model and
            item 1 is passed to the post-processing step.
        post_process_class: Callable ``(preds, labels)`` producing the
            ``post_result`` consumed by update_center.
    Returns: dict mapping each index to its mean feature.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the last batch is skipped.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        # Release the progress bar even if a batch fails.
        pbar.close()

    # Drop the running counts and keep only the mean features.
    return {key: entry[0] for key, entry in char_center.items()}


607
def preprocess(is_train=False):
    """
    Parse the command line, build the config and set up device and loggers.
    Args:
        is_train (bool): When true, the merged config is saved to
            ``<save_model_dir>/config.yml`` and logs are also written to
            ``<save_model_dir>/train.log``.
    Returns: tuple ``(config, device, logger, log_writer)``; ``log_writer``
        is None when neither VisualDL nor W&B logging is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']

    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = config['Global'].get('use_xpu', False)
    check_xpu(use_xpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner'
    ]

    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        loggers.append(VDLLogger(vdl_writer_path))
    if ('use_wandb' in config['Global'] and
            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        # Use save_dir here: save_model_dir is only bound when is_train is
        # true or VisualDL is enabled, so referencing it would fail otherwise.
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    # Wrap all enabled writers behind a single interface.
    if loggers:
        log_writer = Loggers(loggers)
    else:
        log_writer = None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer