program.py 24.9 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27
import paddle
import paddle.distributed as dist
from tqdm import tqdm
X
xiaoting 已提交
28 29
import cv2
import numpy as np
W
WenmuZhou 已提交
30 31
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
32 33
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
34
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
35
from ppocr.utils.logging import get_logger
36
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
L
LDOUBLEV 已提交
37
from ppocr.utils import profiler
D
dyning 已提交
38
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
39

D
dyning 已提交
40

L
LDOUBLEV 已提交
41 42 43 44 45 46 47
class ArgsParser(ArgumentParser):
    """Command-line parser for ``-c/--config``, ``-o/--opt`` and ``-p/--profiler_options``."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """
        Parse ``argv`` and decode the ``--opt`` overrides into a dict.

        Args:
            argv (list, optional): Arguments to parse; argparse falls back to
                ``sys.argv[1:]`` when None.
        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict (possibly empty).
        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """
        Turn ``["key=value", ...]`` into ``{key: yaml-decoded value}``.

        Only the first ``=`` separates key from value, so values may contain
        ``=`` themselves (e.g. ``Global.infer_img=a=b.jpg``).

        Raises:
            ValueError: if an option has no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "Invalid option '{}', expected format key=value.".format(
                        s))
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python objects
            # from tags like !!python/object; acceptable for a local CLI, but
            # never feed it untrusted input.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the extension is not ``.yml`` or ``.yaml``.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Close the file deterministically instead of leaking the handle until
    # garbage collection.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load config files from trusted sources.
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
86 87


88
def merge_config(config, opts):
    """
    Merge override options into the global config, in place.

    Args:
        config (dict): Global config to update; it is modified in place.
        opts (dict): Overrides. A key without a dot replaces the top-level
            entry, except that a dict value is merged into an existing
            top-level entry via ``update``. A dotted key such as
            ``Global.epoch_num`` walks the nested dicts and assigns the last
            component; its first component must already exist in ``config``.
    Returns: global config (the same object that was passed in)
    """
    for key, value in opts.items():
        if "." in key:
            parts = key.split('.')
            root = parts[0]
            assert (
                root in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), root)
            node = config[root]
            # Descend through the intermediate components, then set the leaf.
            for part in parts[1:-1]:
                node = node[part]
            node[parts[-1]] = value
        elif isinstance(value, dict) and key in config:
            config[key].update(value)
        else:
            config[key] = value
    return config
L
LDOUBLEV 已提交
115 116


117
def check_device(use_gpu, use_xpu=False, use_npu=False):
    """
    Print an error and exit when a device flag is enabled in the config but
    the installed paddlepaddle build was not compiled for that device.

    Args:
        use_gpu (bool): ``use_gpu`` flag from the config.
        use_xpu (bool): ``use_xpu`` flag from the config.
        use_npu (bool): ``use_npu`` flag from the config.

    Calls ``sys.exit(1)`` when a requested device is unsupported. Enabling
    both GPU and XPU only prints a warning.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
        if use_npu and not paddle.device.is_compiled_with_npu():
            print(err.format("use_npu", "npu", "npu", "use_npu"))
            sys.exit(1)
    except Exception:
        # Deliberately best-effort: if a capability probe itself fails (e.g.
        # a paddle build without paddle.device.is_compiled_with_npu), skip the
        # check instead of aborting. SystemExit is not an Exception subclass,
        # so the sys.exit(1) calls above still terminate the process.
        pass

文幕地方's avatar
文幕地方 已提交
143

文幕地方's avatar
文幕地方 已提交
144 145 146 147 148
def to_float32(preds):
    """
    Cast every ``paddle.Tensor`` found in ``preds`` to float32.

    Dicts and lists are walked recursively and updated in place, and the same
    container object is returned. A bare tensor is returned cast. Any other
    value is returned unchanged. In this file it is applied to predictions
    produced under ``paddle.amp.auto_cast``.
    """
    if isinstance(preds, (dict, list)):
        # Dicts are addressed by key, lists by position.
        slots = preds if isinstance(preds, dict) else range(len(preds))
        for slot in slots:
            item = preds[slot]
            if isinstance(item, (dict, list)):
                preds[slot] = to_float32(item)
            elif isinstance(item, paddle.Tensor):
                preds[slot] = item.astype(paddle.float32)
        return preds
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    return preds
162

文幕地方's avatar
文幕地方 已提交
163

W
WenmuZhou 已提交
164
def _train_forward(model, batch, model_type, extra_input):
    """Run the model forward pass for one batch, choosing its input layout.

    Table models and algorithms that need extra inputs get ``batch[0]`` plus
    the rest of the batch as ``data``. KIE and SR models consume the whole
    batch. Every other model takes only ``batch[0]``.
    """
    images = batch[0]
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ["kie", 'sr']:
        return model(batch)
    return model(images)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2',
          amp_custom_black_list=None):
    """
    Run the training loop with periodic logging, evaluation and checkpoints.

    Args:
        config (dict): Global config. Reads the ``Global``, ``Architecture``
            and ``Loss`` sections and the top-level ``profiler_options`` key.
        train_dataloader: Training loader. It is rebuilt with
            ``build_dataloader`` at the start of an epoch when
            ``train_dataloader.dataset.need_reset`` is true.
        valid_dataloader: Loader passed to ``eval`` for evaluation.
        device: Device passed to ``build_dataloader`` when rebuilding.
        model: Model to train.
        loss_class: Callable ``(preds, batch)`` returning a dict of losses;
            its ``'loss'`` entry is backpropagated.
        optimizer: Optimizer stepped and cleared every iteration.
        lr_scheduler: Stepped every iteration unless it is a plain float.
        post_process_class: Turns predictions into results for ``eval_class``.
        eval_class: Metric object exposing ``main_indicator`` and
            ``get_metric()``.
        pre_best_model_dict (dict): Resumed state. May contain
            ``global_step``, ``start_epoch`` and previous best metrics.
        logger: Logger for progress messages.
        log_writer: Optional metric/model logger (rank 0 only).
        scaler: Optional AMP grad scaler; enables the mixed-precision path.
        amp_level (str): Level passed to ``paddle.amp.auto_cast``.
        amp_custom_black_list (list, optional): Ops kept out of AMP casting.
            ``None`` means an empty list.

    On rank 0, evaluation runs every ``eval_batch_step`` iterations after
    ``start_eval_step``, and ``best_accuracy`` is saved when the main
    indicator does not decrease. ``latest`` is saved every epoch, and
    ``iter_epoch_N`` every ``save_epoch_step`` epochs.
    """
    if amp_custom_black_list is None:
        amp_custom_black_list = []

    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    # eval_batch_step may be [start_step, interval] or a plain interval.
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # Push the first evaluation out of reach.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    algorithm = config['Architecture']['algorithm']
    use_srn = algorithm == "SRN"
    # Algorithms whose forward pass also needs the rest of the batch.
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner"
    ]
    if algorithm == 'Distillation':
        extra_input = any(
            sub_cfg['algorithm'] in extra_input_models
            for sub_cfg in config['Architecture']["Models"].values())
    else:
        extra_input = algorithm in extra_input_models
    # model_type is optional in the Architecture section.
    model_type = config['Architecture'].get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    # Iterations accumulated into the three counters above since the last
    # progress log; used to average them correctly.
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch of each epoch is skipped.
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True
            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    preds = _train_forward(model, batch, model_type,
                                           extra_input)
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                preds = _train_forward(model, batch, model_type, extra_input)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)
            steps_since_log += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            # Log every print_batch_step steps and on the last processed
            # iteration of the epoch (idx == max_iter - 1, which also holds on
            # Windows where the final batch is skipped).
            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= max_iter - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) * \
                    max_iter - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                    '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                    'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / steps_since_log,
                    train_batch_cost / steps_since_log,
                    total_samples / steps_since_log,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    model_averager = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    model_averager.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return


M
refine  
MissPenguin 已提交
460 461 462 463
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=None):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return its metrics.

    Runs under ``paddle.no_grad()`` with the model in eval mode, and switches
    the model back to train mode before returning.

    Args:
        model: Model to evaluate.
        valid_dataloader: Evaluation data loader.
        post_process_class: Turns predictions into results for
            ``eval_class``. May be None for table/kie models.
        eval_class: Metric object called per batch; ``get_metric()`` gives
            the final result.
        model_type (str, optional): Selects how inputs and outputs are routed
            ('table', 'kie', 'sr' or other).
        extra_input (bool): Pass the rest of the batch as ``data=``.
        scaler: When truthy, run the forward pass under AMP and cast the
            predictions back to float32.
        amp_level (str): Level passed to ``paddle.amp.auto_cast``.
        amp_custom_black_list (list, optional): Ops kept out of AMP casting.
            ``None`` means an empty list.
    Returns:
        dict: ``eval_class.get_metric()`` plus ``'fps'``, the samples per
        second of forward pass plus numpy conversion (0.0 if no time was
        measured).
    """
    if amp_custom_black_list is None:
        amp_custom_black_list = []
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        # On Windows the last batch is skipped.
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        for idx, batch in enumerate(valid_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            start = time.time()

            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie", 'sr']:
                        preds = model(batch)
                    else:
                        preds = model(images)
                preds = to_float32(preds)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'sr']:
                    preds = model(batch)
                else:
                    preds = model(images)

            batch_numpy = [
                item.numpy() if isinstance(item, paddle.Tensor) else item
                for item in batch
            ]
            # Timed span: forward pass plus conversion of the batch to numpy.
            total_time += time.time() - start
            # Evaluate the results of the current batch
            if model_type in ['table', 'kie']:
                if post_process_class is None:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
            elif model_type in ['sr']:
                eval_class(preds, batch_numpy)
            else:
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            pbar.update(1)
            total_frame += len(images)
        # Get final metric,eg. acc or hmean
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    # Guard against an empty loader (or a single batch on Windows), and
    # against a timer too coarse to register any elapsed time.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
546

T
tink2123 已提交
547

B
Bin Lu 已提交
548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596
def update_center(char_center, post_result, preds):
    """
    Fold one batch into the running per-character feature centers.

    Args:
        char_center (dict): Maps a character index to ``[mean_feature, count]``.
            Updated in place.
        post_result: ``(result, label)`` pair. Only samples where
            ``result[i][0] == label[i][0]`` contribute.
        preds: ``(feats, logits)`` pair of tensors. The argmax of ``logits``
            over the last axis gives the character index at each time step.
    Returns:
        dict: the same ``char_center`` object.
    """
    result_batch, label_batch = post_result
    feat_tensor, logit_tensor = preds
    char_id_tensor = paddle.argmax(logit_tensor, axis=-1)
    feat_arr = feat_tensor.numpy()
    char_id_arr = char_id_tensor.numpy()

    for sample_idx, label in enumerate(label_batch):
        # Skip samples whose post-processed prediction disagrees with the label.
        if not (result_batch[sample_idx][0] == label[0]):
            continue
        sample_feat = feat_arr[sample_idx]
        for step, char_id in enumerate(char_id_arr[sample_idx]):
            if char_id in char_center:
                entry = char_center[char_id]
                # Incremental mean: (old_mean * n + x) / (n + 1).
                entry[0] = (entry[0] * entry[1] + sample_feat[step]) / (
                    entry[1] + 1)
                entry[1] += 1
            else:
                char_center[char_id] = [sample_feat[step], 1]
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """
    Compute per-character feature centers over ``eval_dataloader``.

    Runs ``model`` on each batch, post-processes the predictions and folds
    them into running means with ``update_center``.

    Returns:
        dict: maps each character index to its running-mean feature.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the last batch is skipped.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        # Close the progress bar even if the model or post-processing raises.
        pbar.close()
    # Drop the sample counts, keeping only the mean feature per character.
    return {key: entry[0] for key, entry in char_center.items()}


597
def preprocess(is_train=False):
    """
    Parse the command line and build the config, logger, device and loggers.

    Args:
        is_train (bool): When True, the merged config is saved to
            ``<save_model_dir>/config.yml`` and the logger also writes to
            ``<save_model_dir>/train.log``.
    Returns:
        tuple: ``(config, device, logger, log_writer)``. ``log_writer`` is a
        ``Loggers`` wrapping every enabled metric logger (VisualDL / W&B), or
        None when none is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = config['Global'].get('use_xpu', False)
    use_npu = config['Global'].get('use_npu', False)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT'
    ]

    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if config['Global'].get('use_visualdl'):
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        loggers.append(VDLLogger(vdl_writer_path))
    if config['Global'].get('use_wandb') or 'wandb' in config:
        # Read save_model_dir here: the variable set above only exists when
        # is_train or VisualDL is enabled.
        save_dir = config['Global']['save_model_dir']
        # Copy so the user's 'wandb' config section is not mutated, and treat
        # an empty section (None) as no extra parameters.
        wandb_params = dict(config.get('wandb') or {})
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer