# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import yaml
import time
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np


class ArgsParser(ArgumentParser):
    """Command-line parser for ``-c/--config`` and ``-o/--opt`` overrides.

    ``-o`` accepts one or more ``key=value`` items; every value is decoded
    with YAML so numbers, booleans and lists arrive with their natural types.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` (argparse default: ``sys.argv[1:]``) and decode ``--opt``.

        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict mapping each
            override key to its YAML-decoded value ({} when ``-o`` is absent).
        Raises:
            AssertionError: if ``--config`` was not supplied.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['Global.epoch_num=10', ...]`` into ``{'Global.epoch_num': 10, ...}``.

        Args:
            opts (list[str] | None): raw ``key=value`` strings from ``-o``.
        Returns:
            dict: key -> YAML-decoded value; empty when ``opts`` is empty/None.
        Raises:
            ValueError: if an item contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "option '{}' is not of the form key=value".format(s))
            # Split on the first '=' only: values (paths, expressions) may
            # legitimately contain '=' themselves.
            k, v = s.split('=', 1)
            # NOTE: yaml.Loader can construct arbitrary Python objects from
            # tagged YAML. Acceptable for a locally invoked CLI, but never
            # route untrusted strings through here.
            config[k.strip()] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only the outermost level gets attribute access; nested dict values are
    stored unchanged and must be indexed normally.
    """

    def __init__(self, **kwargs):
        base = super(AttrDict, self)
        base.__init__()
        base.update(kwargs)

    def __getattr__(self, key):
        # Python calls this only after ordinary attribute lookup has failed,
        # so real attributes/methods always win over same-named keys.
        try:
            return self[key]
        except KeyError:
            pass
        raise AttributeError("object has no attribute '{}'".format(key))


# Process-wide configuration store: merge_config() writes into it and
# load_config() returns it.
global_config = AttrDict()

# Defaults merged into global_config before the YAML file is read, so that
# Global.debug exists even when the file does not set it.
default_config = {'Global': {'debug': False, }}


def load_config(file_path):
    """
    Load config from yml/yaml file.

    The built-in ``default_config`` is merged into ``global_config`` first,
    then the file's contents are merged on top of it.

    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if ``file_path`` does not end in ``.yml``/``.yaml``.
    """
    import copy

    # Merge a deep copy: merge_config() may store these dicts directly in
    # global_config and later update them in place. Without the copy, one
    # file's settings would leak into default_config for later calls.
    merge_config(copy.deepcopy(default_config))
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically.
    # NOTE: yaml.Loader can instantiate arbitrary Python objects from tagged
    # YAML; only load trusted config files.
    with open(file_path, 'rb') as f:
        merge_config(yaml.load(f, Loader=yaml.Loader))
    return global_config


def merge_config(config):
    """
    Merge config into global config.

    Keys without a dot are top-level sections. If both the incoming value
    and an existing section are present and the value is a dict, it is
    merged one level deep via ``dict.update``; otherwise the value replaces
    the section. Dotted keys such as ``Global.epoch_num`` address a nested
    entry; every intermediate dict on the path must already exist.

    Args:
        config (dict): Config to be merged.
    Returns: global config
    Raises:
        AssertionError: if a dotted key's first component is not an existing
            top-level section.
        KeyError: if an intermediate component of a dotted key is missing.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            # Descend to the parent of the last component, then assign.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    # Return the merged config, as the docstring has always promised.
    return global_config


def check_gpu(use_gpu):
    """
    Print an error and exit(1) when GPU execution is requested but the
    installed paddlepaddle build reports no CUDA support.

    The check is best-effort: if the CUDA probe itself raises, startup
    continues.

    Args:
        use_gpu (bool): whether the config requests GPU execution.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"

    if not use_gpu:
        return
    try:
        compiled_with_cuda = paddle.is_compiled_with_cuda()
    except Exception:
        # Best-effort: a failing capability probe must not block startup.
        # Only the probe is guarded, so sys.exit() below is never masked.
        return
    if not compiled_with_cuda:
        print(err, file=sys.stderr)
        sys.exit(1)


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None):
    """
    Run the training loop with periodic logging, evaluation and checkpoints.

    Args:
        config (dict): full config. Reads Global.{cal_metric_during_train,
            log_smooth_window, epoch_num, print_batch_step, eval_batch_step,
            save_epoch_step, save_model_dir} and Architecture.algorithm.
        train_dataloader: training loader; rebuilt via build_dataloader() at
            the start of every epoch after epoch 0.
        valid_dataloader: loader passed to eval().
        device: device passed to build_dataloader() when rebuilding.
        model: network; called as model(images), or model(images, others)
            for SRN.
        loss_class: called as loss_class(preds, batch); must return a dict
            of losses whose 'loss' entry is back-propagated.
        optimizer: provides get_lr(), step() and clear_grad().
        lr_scheduler: stepped once per iteration unless it is a float.
        post_process_class: called as post_process_class(preds, batch[1]).
        eval_class: metric accumulator; its main_indicator selects the best
            model.
        pre_best_model_dict (dict): metrics restored from a previous run;
            may contain 'start_epoch' to resume from.
        logger: logger for progress messages.
        vdl_writer: optional VisualDL writer; closed on rank 0 at the end.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    start_eval_step = 0
    # eval_batch_step may be [start_step, interval]: no evaluation until
    # start_step, then one every `interval` iterations.
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(save_model_dir, exist_ok=True)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model.train()

    # Resume from the epoch recorded in a restored checkpoint, if any.
    start_epoch = best_model_dict.get('start_epoch', 0)

    for epoch in range(start_epoch, epoch_num):
        if epoch > 0:
            train_dataloader = build_dataloader(config, 'Train', device, logger)
        # Timing/throughput accumulators, reset after every log line.
        train_batch_cost = 0.0
        train_reader_cost = 0.0
        batch_sum = 0  # samples seen since the last log line
        batch_count = 0  # iterations seen since the last log line
        batch_start = time.time()
        for idx, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - batch_start
            # NOTE(review): presumably guards loaders that yield more
            # batches than len() reports — confirm.
            if idx >= len(train_dataloader):
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if config['Architecture']['algorithm'] == "SRN":
                # SRN additionally consumes the last four batch entries.
                others = batch[-4:]
                preds = model(images, others)
            else:
                preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)
            batch_count += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                post_result = post_process_class(preds, batch[1])
                eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank(
            ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                logs = train_stats.log()
                # Average over the iterations actually accumulated. The
                # window is not always print_batch_step long: it spans
                # print_batch_step + 1 iterations before the first log, and
                # it is cut short when the counters reset at an epoch start.
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / batch_count,
                    train_batch_cost / batch_count, batch_sum,
                    batch_sum / train_batch_cost)
                logger.info(strs)
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
                batch_count = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                cur_metric = eval(model, valid_dataloader, post_process_class,
                                  eval_class)
                cur_metric_str = 'cur metirc, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k), v,
                                                  global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch)
                best_str = 'best metirc, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            # Restart the timer after eval so eval time is not charged to
            # the next training batch.
            batch_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch)
    best_str = 'best metirc, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


def eval(model, valid_dataloader, post_process_class, eval_class):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return the metrics.

    The model is put in eval mode for the pass. It is switched back with
    model.train() afterwards, even if evaluation raises.

    Args:
        model: network; called as model(images, others), where ``others``
            is the last four entries of each batch.
        valid_dataloader: validation loader; batch[0] is fed to the model.
        post_process_class: called as post_process_class(preds, batch[1]).
        eval_class: metric accumulator, called with (post_result, batch) per
            batch; its get_metric() supplies the returned dict.
    Returns:
        dict: eval_class.get_metric() plus 'fps' = images per second of
        forward + post-processing time (0.0 when nothing was timed, e.g.
        for an empty loader).
    """
    model.eval()
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    try:
        with paddle.no_grad():
            total_frame = 0.0
            total_time = 0.0
            for idx, batch in enumerate(valid_dataloader):
                if idx >= len(valid_dataloader):
                    break
                images = batch[0]
                # NOTE(review): unlike train(), `others` is passed for every
                # algorithm; this assumes model.forward accepts an optional
                # second argument — confirm for non-SRN architectures.
                others = batch[-4:]
                start = time.time()
                preds = model(images, others)

                batch = [item.numpy() for item in batch]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])
                total_time += time.time() - start
                # Evaluate the results of the current batch
                eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        # Always release the progress bar and return to training mode.
        pbar.close()
        model.train()
    # Guard the division: an empty loader yields total_time == 0.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric


def preprocess(is_train=False):
    """
    Parse the command line, build the merged config and set up the runtime.

    The steps, in order:
    - parse -c/-o with ArgsParser;
    - load the YAML config and apply the -o overrides;
    - check the GPU request against the paddle build;
    - validate Architecture.algorithm;
    - select the device;
    - create the logger and, optionally, a VisualDL writer.

    Args:
        is_train (bool): when True, the merged config is written to
            <save_model_dir>/config.yml and logs go to
            <save_model_dir>/train.log.
    Returns:
        tuple: (config, device, logger, vdl_writer). vdl_writer is None
        unless Global.use_visualdl is true.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    supported_algorithms = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
    ]
    assert alg in supported_algorithms, \
        "algorithm {} is not supported, expected one of {}".format(
            alg, supported_algorithms)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1
    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)
    # use_visualdl is treated as optional (default off).
    if config['Global'].get('use_visualdl', False):
        from visualdl import LogWriter
        # Read save_model_dir here as well: above it is only bound when
        # is_train is True, so VisualDL outside training would otherwise
        # raise UnboundLocalError.
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer