program.py 25.2 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
24
import datetime
W
WenmuZhou 已提交
25 26 27
import paddle
import paddle.distributed as dist
from tqdm import tqdm
X
xiaoting 已提交
28 29
import cv2
import numpy as np
W
WenmuZhou 已提交
30 31
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
32 33
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
34
from ppocr.utils.utility import print_dict, AverageMeter
D
dyning 已提交
35
from ppocr.utils.logging import get_logger
36
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
L
LDOUBLEV 已提交
37
from ppocr.utils import profiler
D
dyning 已提交
38
from ppocr.data import build_dataloader
L
LDOUBLEV 已提交
39

D
dyning 已提交
40

L
LDOUBLEV 已提交
41 42 43 44 45 46 47
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the PaddleOCR tools.

    Options:
        -c/--config: path of the YAML config file (required by parse_args).
        -o/--opt: one or more ``key=value`` overrides; each value is parsed
            as YAML so ``Global.epoch_num=10`` yields an int.
        -p/--profiler_options: raw profiler option string.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None).

        ``args.opt`` is converted from a list of ``key=value`` strings into
        a dict (empty when ``-o`` was not given).
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["k1=v1", "k2=v2", ...]`` into ``{"k1": v1, ...}``.

        Only the first ``=`` separates key from value, so values may contain
        ``=`` themselves. Values are parsed as YAML.

        Raises:
            ValueError: if an item contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "Invalid --opt item {!r}: expected key=value".format(s))
            k, v = s.split('=', 1)
            # NOTE: yaml.Loader is PyYAML's non-safe loader and can build
            # arbitrary Python objects; fine for options typed by the user,
            # but never pass untrusted strings through here.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded. Must end in
            ``.yml`` or ``.yaml``.
    Returns: global config (the object produced by the YAML loader).
    Raises:
        AssertionError: if the extension is not ``.yml``/``.yaml``.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Close the handle deterministically instead of leaking it until GC.
    with open(file_path, 'rb') as f:
        # NOTE: yaml.Loader is PyYAML's non-safe loader; config files are
        # assumed to be trusted.
        config = yaml.load(f, Loader=yaml.Loader)
    return config
L
LDOUBLEV 已提交
86 87


88
def merge_config(config, opts):
    """
    Apply override options ``opts`` onto the global ``config`` in place.

    Args:
        config (dict): The global config that receives the overrides.
        opts (dict): Overrides to apply. For a key without a dot, a dict
            value is shallow-merged into an existing entry of the same
            name; otherwise the entry is replaced/created. A dotted key
            such as ``"Global.epoch_num"`` descends through nested dicts and
            assigns its last component; its first component must already
            exist in ``config``.
    Returns: global config (the same object as ``config``)
    """
    for key, value in opts.items():
        if "." in key:
            path = key.split('.')
            root = path[0]
            assert (
                root in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
               "{}, please check your running command".format(
                config.keys(), root)
            node = config[root]
            # Descend to the parent of the final component, then assign.
            for part in path[1:-1]:
                node = node[part]
            node[path[-1]] = value
            continue

        if key in config and isinstance(value, dict):
            config[key].update(value)
        else:
            config[key] = value
    return config
L
LDOUBLEV 已提交
115 116


117
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
    """
    Print an error and exit when a device flag is enabled in the config but
    the installed paddlepaddle build was not compiled with that device.

    Args:
        use_gpu (bool): value of ``Global.use_gpu``.
        use_xpu (bool): value of ``Global.use_xpu``.
        use_npu (bool): value of ``Global.use_npu``.
        use_mlu (bool): value of ``Global.use_mlu``.

    Raises:
        SystemExit: with code 1 on the first unsupported device flag.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    if use_gpu and use_xpu:
        # Only a warning; execution continues.
        print("use_xpu and use_gpu can not both be true.")

    if use_gpu and not paddle.is_compiled_with_cuda():
        print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
        sys.exit(1)

    for enabled, backend in ((use_xpu, "xpu"), (use_npu, "npu"),
                             (use_mlu, "mlu")):
        if not enabled:
            continue
        # Look the checker up defensively: if this paddle build does not
        # expose ``is_compiled_with_<backend>``, support cannot be verified
        # here, so only that one check is skipped (best effort) instead of
        # silently abandoning all remaining checks.
        checker = getattr(paddle.device, "is_compiled_with_" + backend, None)
        if checker is not None and not checker():
            flag = "use_" + backend
            print(err.format(flag, backend, backend, flag))
            sys.exit(1)

文幕地方's avatar
文幕地方 已提交
146

文幕地方's avatar
文幕地方 已提交
147 148 149 150 151
def to_float32(preds):
    """
    Recursively cast every ``paddle.Tensor`` inside ``preds`` to float32.

    Dicts and lists are walked (including nested dicts/lists), their tensor
    entries are replaced in place, and the same container object is
    returned. A bare tensor is returned as its float32 cast. Any other value
    is returned unchanged.
    """
    if isinstance(preds, (dict, list)):
        slots = preds.keys() if isinstance(preds, dict) else range(len(preds))
        for slot in slots:
            item = preds[slot]
            if isinstance(item, (dict, list)):
                preds[slot] = to_float32(item)
            elif isinstance(item, paddle.Tensor):
                preds[slot] = item.astype(paddle.float32)
        return preds
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    return preds
165

文幕地方's avatar
文幕地方 已提交
166

W
WenmuZhou 已提交
167
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2',
          amp_custom_black_list=[]):
    """
    Run the training loop: forward/backward/optimizer steps, periodic
    progress logging, periodic evaluation with best-model checkpointing,
    and per-epoch checkpointing (checkpoints and eval only on rank 0).

    Args:
        config (dict): Global config; reads the ``Global``, ``Architecture``,
            ``Loss`` and ``profiler_options`` entries.
        train_dataloader: Training loader. Rebuilt (seeded by epoch) when
            ``train_dataloader.dataset.need_reset`` is true.
        valid_dataloader: Loader passed to ``eval`` for evaluation.
        device: Passed to ``build_dataloader`` when rebuilding the loader.
        model: Network to train; switched to train mode here.
        loss_class: Callable ``(preds, batch)`` returning a dict whose
            ``'loss'`` entry is backpropagated.
        optimizer: Provides ``get_lr``, ``step`` and ``clear_grad``.
        lr_scheduler: Stepped every iteration unless it is a float.
        post_process_class: Post-processing callable used for metrics.
        eval_class: Metric object with ``main_indicator`` and
            ``get_metric()``.
        pre_best_model_dict (dict): Previously saved state (presumably from
            a resumed checkpoint); ``global_step``/``start_epoch`` and prior
            best metrics are read from it when present.
        logger: Logger for progress messages.
        log_writer: Optional metric/model writer (e.g. VisualDL/W&B wrapper).
        scaler: AMP grad scaler; when truthy the forward pass runs under
            ``paddle.amp.auto_cast`` and the loss is scaled.
        amp_level (str): AMP level passed to ``auto_cast``.
        amp_custom_black_list (list): Ops kept in float32 under AMP; only
            read, never mutated.
    """
    global_config = config['Global']
    arch_config = config['Architecture']

    cal_metric_during_train = global_config.get('cal_metric_during_train',
                                                False)
    calc_epoch_interval = global_config.get('calc_epoch_interval', 1)
    log_smooth_window = global_config['log_smooth_window']
    epoch_num = global_config['epoch_num']
    print_batch_step = global_config['print_batch_step']
    eval_batch_step = global_config['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']

    # eval_batch_step is either an int interval (evaluate from step 0) or
    # [start_step, interval].
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # Push the first evaluation past any reachable step.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = global_config['save_epoch_step']
    save_model_dir = global_config['save_model_dir']
    # exist_ok avoids a check-then-create race between processes.
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    algorithm = arch_config['algorithm']
    use_srn = algorithm == "SRN"
    # Algorithms whose forward pass is called as model(images, data=batch[1:]).
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner", "RFL", 'DRRG'
    ]
    if algorithm == 'Distillation':
        # A distillation wrapper needs extra inputs if any sub-model does.
        extra_input = any(
            sub_model['algorithm'] in extra_input_models
            for sub_model in arch_config['Models'].values())
    else:
        extra_input = algorithm in extra_input_models
    model_type = arch_config.get('model_type')

    start_epoch = best_model_dict.get('start_epoch', 1)

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    # Iterations accumulated into the cost/sample counters since the last
    # progress log; used to compute correct per-iteration averages.
    steps_since_log = 0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the last batch of every epoch is skipped (see the break).
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True
            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    # Same input dispatch as the non-AMP path: 'sr' models
                    # consume the whole batch.
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie", 'sr']:
                        preds = model(batch)
                    else:
                        preds = model(images)
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'sr']:
                    preds = model(batch)
                else:
                    preds = model(images)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)
            steps_since_log += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                # Average over the iterations actually accumulated since the
                # previous log: the end-of-epoch log (and the first log after
                # resuming) can cover fewer than print_batch_step iterations.
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                    '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                    'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / steps_since_log,
                    train_batch_cost / steps_since_log,
                    total_samples / steps_since_log,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                steps_since_log = 0

            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): a fresh ModelAverage is built at every
                    # evaluation and apply() is not used as a context
                    # manager; confirm against the paddle ModelAverage docs
                    # that this has the intended effect.
                    model_averager = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    model_averager.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                # Ties also count as "best" so the newest equal model is kept.
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()

        # End of epoch: always refresh the 'latest' checkpoint on rank 0.
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()


M
refine  
MissPenguin 已提交
463 464 465 466
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=[]):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return its metrics.

    The model is put in eval mode for the duration of the call and is put
    back in train mode afterwards, even if evaluation raises.

    Args:
        model: Network to evaluate.
        valid_dataloader: Evaluation loader; ``batch[0]`` is the model input.
        post_process_class: Callable turning raw predictions into results
            for ``eval_class``; may be None for 'table'/'kie' models.
        eval_class: Metric accumulator called once per batch;
            ``get_metric()`` returns the final metric dict.
        model_type (str|None): 'table', 'kie', 'sr' or other; selects how
            the model is called and how its output is evaluated.
        extra_input (bool): If True, call ``model(images, data=batch[1:])``.
        scaler: When truthy, run the forward pass under
            ``paddle.amp.auto_cast`` and cast outputs back to float32.
        amp_level (str): AMP level passed to ``auto_cast``.
        amp_custom_black_list (list): Ops kept in float32 under AMP; only
            read, never mutated.

    Returns:
        dict: Metrics from ``eval_class.get_metric()`` plus ``'fps'``
        (samples per second over the timed forward + conversion section;
        0.0 when no time was measured).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    try:
        with paddle.no_grad():
            # On Windows the final batch is skipped.
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()

                # use amp
                if scaler:
                    with paddle.amp.auto_cast(
                            level=amp_level,
                            custom_black_list=amp_custom_black_list):
                        if model_type == 'table' or extra_input:
                            preds = model(images, data=batch[1:])
                        elif model_type in ["kie", 'sr']:
                            preds = model(batch)
                        else:
                            preds = model(images)
                    preds = to_float32(preds)
                else:
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie", 'sr']:
                        preds = model(batch)
                    else:
                        preds = model(images)

                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                # Timing covers the forward pass and the numpy conversion.
                total_time += time.time() - start

                # Evaluate the results of the current batch
                if model_type in ['table', 'kie']:
                    if post_process_class is None:
                        eval_class(preds, batch_numpy)
                    else:
                        post_result = post_process_class(preds, batch_numpy)
                        eval_class(post_result, batch_numpy)
                elif model_type in ['sr']:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()

    # Avoid ZeroDivisionError when nothing was timed (e.g. an empty loader,
    # or a single-batch loader on Windows).
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
549

T
tink2123 已提交
550

B
Bin Lu 已提交
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599
def update_center(char_center, post_result, preds):
    """
    Fold one batch of features into running per-class feature means.

    Only samples where ``result[i][0] == label[i][0]`` contribute. For each
    timestep of such a sample, the feature at that step is averaged into the
    entry of the class id predicted (argmax over the last axis of
    ``logits``) at that step.

    Args:
        char_center (dict): class id -> ``[mean_feature, count]``; updated
            in place.
        post_result: ``(result, label)`` pair from the post-processor.
        preds: ``(feats, logits)`` tensors from the model.

    Returns:
        dict: the same ``char_center`` object.
    """
    result, label = post_result
    feats, logits = preds
    class_ids = paddle.argmax(logits, axis=-1).numpy()
    feats = feats.numpy()

    for sample_idx, sample_label in enumerate(label):
        if result[sample_idx][0] == sample_label[0]:
            sample_feat = feats[sample_idx]
            for step, cls_id in enumerate(class_ids[sample_idx]):
                entry = char_center.get(cls_id)
                if entry is None:
                    char_center[cls_id] = [sample_feat[step], 1]
                else:
                    running_mean, count = entry
                    entry[0] = (running_mean * count + sample_feat[step]) / (
                        count + 1)
                    entry[1] = count + 1
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    """
    Compute per-class feature centers over ``eval_dataloader``.

    Each batch is run through ``model``, post-processed, and accumulated
    with ``update_center``.

    Args:
        model: Network whose output is a ``(feats, logits)`` pair.
        eval_dataloader: Loader yielding batches of tensors; ``batch[0]`` is
            the model input and ``batch[1]`` is handed to the post-processor.
        post_process_class: Callable returning a ``(result, label)`` pair.

    Returns:
        dict: class id -> mean feature vector.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the final batch is skipped.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    # Outputs are only read back as numpy, so no autograd graph is needed.
    with paddle.no_grad():
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # update char_center
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)

    pbar.close()
    # Drop the sample counts, keeping only the mean feature per class.
    return {key: entry[0] for key, entry in char_center.items()}


600
def preprocess(is_train=False):
    """
    Parse CLI arguments and prepare config, logger, device and writers.

    Args:
        is_train (bool): When True, the merged config is dumped to
            ``<save_model_dir>/config.yml`` and logs also go to
            ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, log_writer)`` where ``config`` is
        the merged config dict, ``device`` is the value returned by
        ``paddle.set_device``, and ``log_writer`` is a ``Loggers`` wrapper
        around the enabled VisualDL/W&B writers, or None if none is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    global_config = config['Global']

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = global_config.get('use_gpu', False)
    use_xpu = global_config.get('use_xpu', False)
    use_npu = global_config.get('use_npu', False)
    use_mlu = global_config.get('use_mlu', False)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG'
    ]

    # Device priority: xpu > npu > mlu > gpu > cpu.
    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    elif use_mlu:
        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu, use_mlu)

    device = paddle.set_device(device)

    global_config['distributed'] = dist.get_world_size() != 1

    loggers = []

    if global_config.get('use_visualdl'):
        save_model_dir = global_config['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        loggers.append(VDLLogger(vdl_writer_path))
    if global_config.get('use_wandb') or 'wandb' in config:
        # Use a local here: ``save_model_dir`` is only bound on the
        # is_train / use_visualdl paths above and raised NameError otherwise.
        save_dir = global_config['save_model_dir']
        wandb_params = config['wandb'] if "wandb" in config else dict()
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    log_writer = Loggers(loggers) if loggers else None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer