program.py 15.5 KB
Newer Older
M
refine  
MissPenguin 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
LDOUBLEV 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WenmuZhou 已提交
19
import os
L
LDOUBLEV 已提交
20
import sys
21
import platform
L
LDOUBLEV 已提交
22 23
import yaml
import time
W
WenmuZhou 已提交
24 25 26 27 28 29
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

L
LDOUBLEV 已提交
30 31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
D
dyning 已提交
32 33 34 35
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
L
LDOUBLEV 已提交
36

D
dyning 已提交
37

L
LDOUBLEV 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
class ArgsParser(ArgumentParser):
    """Command-line parser for the training/evaluation tools.

    Accepts ``-c/--config`` (path to the config file, required) and
    ``-o/--opt`` (zero or more ``key=value`` overrides whose values are
    parsed as YAML, e.g. ``-o Global.epoch_num=10 Global.use_gpu=false``).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert ``--opt`` into a dict.

        Args:
            argv (list[str] | None): arguments to parse; ``None`` lets
                argparse read ``sys.argv[1:]``.
        Returns:
            argparse.Namespace with ``config`` (str) and ``opt`` (dict).
        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into ``{key: yaml_value}``.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=``.

        Raises:
            ValueError: if an option contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only; a plain split('=') failed with
            # "too many values to unpack" for values containing '='.
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}', expected the form key=value".format(
                        s))
            # NOTE: yaml.Loader can construct arbitrary Python objects; this
            # is only acceptable because the input comes from the local CLI.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only the first level is exposed this way; nested dicts are returned
    unchanged (they are not wrapped in ``AttrDict``).
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so real dict methods (keys, items, ...) keep precedence.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(
                "object has no attribute '{}'".format(key)) from None


# Process-wide configuration; merge_config() writes into it and
# load_config() returns it.
global_config = AttrDict()

# Baseline values merged into global_config before the YAML file is loaded.
default_config = dict(Global=dict(debug=False))

L
LDOUBLEV 已提交
81 82 83 84 85 86 87 88

def load_config(file_path):
    """
    Load config from yml/yaml file.

    ``default_config`` is merged into the global config first, then the
    contents of ``file_path`` are merged on top of it via ``merge_config``.

    Args:
        file_path (str): Path of the config file to be loaded; must have a
            ``.yml`` or ``.yaml`` extension.
    Returns: global config
    Raises:
        AssertionError: if the extension is not ``.yml``/``.yaml``.
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Context manager closes the handle deterministically; the previous
    # inline open() leaked it until garbage collection.
    with open(file_path, 'rb') as f:
        # NOTE: yaml.Loader can build arbitrary Python objects; only load
        # config files from trusted sources.
        file_config = yaml.load(f, Loader=yaml.Loader)
    # An empty YAML document loads as None; treat it as "no overrides"
    # instead of crashing inside merge_config.
    merge_config(file_config or {})
    return global_config


def merge_config(config, target=None):
    """
    Merge config into global config.

    Keys without a dot are merged one level deep: if both the new value and
    the existing entry are dicts, the existing dict is updated in place;
    otherwise the entry is replaced. Dotted keys such as ``"Global.use_gpu"``
    address a nested entry: every component except the last must already
    exist, and the last component is assigned the value.

    Args:
        config (dict): Config to be merged. ``None`` or empty is a no-op.
        target (dict, optional): Dict to merge into. Defaults to the
            module-level ``global_config``.
    Returns: the dict that was merged into (``global_config`` by default).
    Raises:
        AssertionError: if the first component of a dotted key is not a
            top-level key of the target.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    if target is None:
        target = global_config
    if not config:
        return target
    for key, value in config.items():
        if "." not in key:
            # Only dict-on-dict is merged; anything else (including a dict
            # replacing a non-dict entry, which used to crash) is replaced.
            if isinstance(value, dict) and isinstance(target.get(key), dict):
                target[key].update(value)
            else:
                target[key] = value
            continue
        sub_keys = key.split('.')
        assert (
            sub_keys[0] in target
        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
            target.keys(), sub_keys[0])
        # Walk down to the parent of the last component, then assign.
        cur = target[sub_keys[0]]
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value
    return target


def check_gpu(use_gpu):
    """
    Exit with status 1 when use_gpu is requested on a CPU-only paddle build.

    Prints a hint and calls ``sys.exit(1)`` if ``use_gpu`` is truthy and
    ``paddle.is_compiled_with_cuda()`` is falsy. Ordinary exceptions raised
    while checking are ignored (best effort). ``SystemExit`` does not derive
    from ``Exception``, so the exit itself is not swallowed.
    """
    message = ("Config use_gpu cannot be set as true while you are "
               "using paddlepaddle cpu version ! \nPlease try: \n"
               "\t1. Install paddlepaddle-gpu to run model on GPU \n"
               "\t2. Set use_gpu as false in config file to run "
               "model on CPU")

    try:
        if not use_gpu or paddle.is_compiled_with_cuda():
            return
        print(message)
        sys.exit(1)
    except Exception:
        pass


W
WenmuZhou 已提交
142
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None):
    """
    Run the training loop with periodic evaluation and checkpointing.

    For each epoch in ``[start_epoch, Global.epoch_num]`` a new training
    dataloader is built with ``seed=epoch``, so the ``train_dataloader``
    argument is replaced on the first epoch. Each step runs forward, loss,
    backward, ``optimizer.step()`` and, unless ``lr_scheduler`` is a float,
    ``lr_scheduler.step()``. On rank 0 the loop also:

    - logs statistics;
    - evaluates on ``valid_dataloader`` according to
      ``Global.eval_batch_step``;
    - saves a ``best_accuracy`` checkpoint whenever the main metric is
      ``>=`` the best so far;
    - saves ``latest`` after every epoch;
    - saves ``iter_epoch_{epoch}`` every ``Global.save_epoch_step`` epochs.

    Args:
        config (dict): full config; reads the ``Global`` and
            ``Architecture`` sections.
        train_dataloader: initial training loader (rebuilt every epoch).
        valid_dataloader: loader used for evaluation.
        device: device passed to ``build_dataloader``.
        model: network to train.
        loss_class: callable ``(preds, batch)`` returning a mapping with a
            ``'loss'`` entry.
        optimizer: optimizer exposing ``get_lr``, ``step`` and ``clear_grad``.
        lr_scheduler: scheduler stepped every iteration, or a float for a
            constant learning rate.
        post_process_class: converts predictions before metric computation.
        eval_class: metric accumulator exposing ``main_indicator`` and
            ``get_metric()``.
        pre_best_model_dict (dict): state restored from a checkpoint; may
            hold ``global_step``, ``start_epoch`` and earlier metrics.
        logger: logger for progress messages.
        vdl_writer: optional VisualDL writer, closed at the end on rank 0.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = pre_best_model_dict.get('global_step', 0)

    # eval_batch_step is either an interval, or [start_step, interval].
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training will be disabled'
            )
            # Effectively "never": global_step will not exceed this.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    model_type = config['Architecture']['model_type']

    start_epoch = best_model_dict.get('start_epoch', 1)

    for epoch in range(start_epoch, epoch_num + 1):
        train_dataloader = build_dataloader(
            config, 'Train', device, logger, seed=epoch)
        train_batch_cost = 0.0
        train_reader_cost = 0.0
        batch_sum = 0
        batch_start = time.time()
        # The last batch is skipped on Windows (presumably a platform
        # dataloader workaround -- TODO confirm).
        max_iter = len(train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - batch_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                # SRN consumes the last four batch items as extra inputs.
                others = batch[-4:]
                preds = model(images, others)
                model_average = True
            elif model_type == "table":
                others = batch[1:]
                preds = model(images, others)
            else:
                preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type == 'table':
                    eval_class(preds, batch)
                else:
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs, train_reader_cost /
                    print_batch_step, train_batch_cost / print_batch_step,
                    batch_sum, batch_sum / train_batch_cost)
                logger.info(strs)
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): a fresh ModelAverage is built at every
                    # evaluation and apply() is called outside a `with`
                    # block; confirm against the paddle docs that this
                    # really averages the weights.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                # Pass the real model type. A hard-coded "table" here made
                # eval() skip post-processing for every non-table model.
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    use_srn=use_srn)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k), v,
                                                  global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            optimizer.clear_grad()
            batch_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()


T
tink2123 已提交
347
def eval(model, valid_dataloader, post_process_class, eval_class,
         model_type, use_srn=False):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.

    Runs under ``paddle.no_grad()`` with the model in eval mode, and always
    puts the model back into train mode before returning, including when an
    error is raised. For ``model_type == 'table'`` the raw predictions go
    straight to ``eval_class``. Otherwise they are first passed through
    ``post_process_class`` together with ``batch[1]``.

    Args:
        model: network to evaluate.
        valid_dataloader: iterable of batches; ``batch[0]`` is the model input.
        post_process_class: callable ``(preds, batch[1])`` producing results
            for the metric.
        eval_class: metric accumulator called once per batch;
            ``get_metric()`` returns the final dict.
        model_type (str): selects the evaluation path (``'table'`` or other).
        use_srn (bool): SRN models take the last four batch items as extra
            inputs.
    Returns:
        dict: metrics from ``eval_class.get_metric()`` plus ``'fps'``, the
        number of images per second of timed work (forward pass and numpy
        conversion). ``'fps'`` is 0.0 when nothing was timed.
    """
    model.eval()
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    total_frame = 0.0
    total_time = 0.0
    try:
        with paddle.no_grad():
            # The last batch is skipped on Windows, as in train().
            max_iter = len(valid_dataloader) - 1 if platform.system(
            ) == "Windows" else len(valid_dataloader)
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()
                if use_srn:
                    others = batch[-4:]
                    preds = model(images, others)
                else:
                    preds = model(images)
                batch = [item.numpy() for item in batch]
                total_time += time.time() - start
                # Evaluate the results of the current batch
                if model_type == 'table':
                    eval_class(preds, batch)
                else:
                    # Obtain usable results from post-processing methods
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        # Restore training mode and release the progress bar even on error.
        pbar.close()
        model.train()
    # No timed batch (empty loader, or a single batch on Windows) would
    # otherwise raise ZeroDivisionError here.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
L
licx 已提交
386

T
tink2123 已提交
387

388
def preprocess(is_train=False):
    """
    Parse CLI arguments, build the global config and set up runtime helpers.

    Steps:

    1. Parse ``-c``/``-o`` with ``ArgsParser``, load the config file and
       merge the overrides.
    2. Check GPU availability and validate ``Architecture.algorithm``.
    3. Select the device and store ``Global.distributed``.
    4. When ``is_train`` is true, write the merged config to
       ``<save_model_dir>/config.yml`` and also log to
       ``<save_model_dir>/train.log``.
    5. Create a VisualDL writer under ``<save_model_dir>/vdl/`` when
       ``Global.use_visualdl`` is truthy.

    Args:
        is_train (bool): whether this process is a training run.
    Returns:
        tuple: ``(config, device, logger, vdl_writer)``; ``vdl_writer`` is
        ``None`` when VisualDL is disabled.
    Raises:
        AssertionError: for a missing ``--config`` or an unsupported
            algorithm.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    supported_algs = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'TableAttn'
    ]
    # Name the offending value; a bare assert gave no hint what was wrong.
    assert alg in supported_algs, \
        "Unsupported Architecture.algorithm '{}', expected one of {}".format(
            alg, supported_algs)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1
    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)
    # A config without a use_visualdl entry simply disables VisualDL instead
    # of raising KeyError.
    if config['Global'].get('use_visualdl', False):
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer