# reader.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# function:
#    Interface to build readers for detection data like COCO or VOC
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from numbers import Integral

import logging
from .source import build_source
from .transform import build_mapper, map, batch, batch_map

logger = logging.getLogger(__name__)


class Reader(object):
    """Interface to make readers for training or evaluation.

    Args:
        data_cf: mapping from a mode name (e.g. 'TRAIN', 'VAL', 'TEST') to
            the data configuration that is handed to ``build_source``.
        trans_conf: mapping from a mode name to its transform configuration.
            Each entry must provide 'OPS' and 'BATCH_SIZE'; 'DROP_LAST',
            'WORKER_CONF' and the batch-level options named in
            ``_BATCH_MAP_KEYS`` are optional.
        maxiter (int): maximum number of batches a reader yields. A value
            <= 0 means a single pass over the dataset.
    """

    # Lower-cased transform-config keys that are forwarded to `batch_map`.
    _BATCH_MAP_KEYS = frozenset([
        'is_padding',
        'coarsest_stride',
        'random_shapes',
        'multi_scales',
        'use_padded_im_info',
    ])

    def __init__(self, data_cf, trans_conf, maxiter=-1):
        assert isinstance(maxiter, Integral), "maxiter should be int"
        self._data_cf = data_cf
        self._trans_conf = trans_conf
        self._maxiter = maxiter
        # Filled in from the data source when a TRAIN reader is built and
        # passed on to sources built afterwards.
        self._cname2cid = None

    @staticmethod
    def _iter_batches(batched_ds, maxit):
        """Yield batches from `batched_ds`, honoring the `maxit` limit.

        Args:
            batched_ds: iterable dataset that also exposes ``reset()``.
            maxit (int): if > 0, keep cycling over the dataset until exactly
                `maxit` batches were yielded; otherwise make a single pass.

        The dataset is reset after every completed pass. A pass that yields
        nothing ends the iteration even when `maxit` > 0, because cycling
        over an empty dataset would otherwise never terminate.
        """
        count = 0
        while True:
            produced = False
            for _batch in batched_ds:
                produced = True
                yield _batch
                count += 1
                if maxit > 0 and count == maxit:
                    return
            batched_ds.reset()
            if maxit <= 0 or not produced:
                return

    def _make_reader(self, mode, my_source=None):
        """Build reader for training or validation.

        Args:
            mode (str): key used to look up both the data and the transform
                configuration; 'TRAIN' and 'VAL' (compared
                case-insensitively) additionally switch mode-specific
                behavior below.
            my_source: optional ready-made data source. When given,
                ``build_source`` is not called and the data configuration
                for `mode` is not read.

        Returns:
            A generator function yielding batches. If the source has a
            ``get_imid2path`` method, its result is attached to the returned
            function as the attribute ``imid2path``.
        """
        # 1, Build data source
        if my_source is None:
            sc_conf = {
                'data_cf': self._data_cf[mode],
                'cname2cid': self._cname2cid,
            }
            sc = build_source(sc_conf)
        else:
            sc = my_source

        # 2, Build a transformed dataset
        mode_conf = self._trans_conf[mode]
        # Normalize once so every mode check below treats case the same way.
        mode_key = mode.upper()

        ops = mode_conf['OPS']
        batchsize = mode_conf['BATCH_SIZE']
        drop_last = mode_conf['DROP_LAST'] \
            if 'DROP_LAST' in mode_conf else False

        mapper = build_mapper(ops, {'is_train': mode_key == 'TRAIN'})

        worker_args = None
        if 'WORKER_CONF' in mode_conf:
            worker_args = {
                k.lower(): v
                for k, v in mode_conf['WORKER_CONF'].items()
            }

        # NOTE: `map` and `batch` are the helpers imported from `.transform`,
        # not the Python builtins.
        mapped_ds = map(sc, mapper, worker_args)
        # In VAL mode, gt_bbox, gt_label can be empty, and should
        # not be dropped
        batched_ds = batch(
            mapped_ds, batchsize, drop_last, drop_empty=(mode_key != 'VAL'))

        # Forward only the batch-level options, with lower-cased keys.
        bm_config = {
            key.lower(): value
            for key, value in mode_conf.items()
            if key.lower() in self._BATCH_MAP_KEYS
        }
        batched_ds = batch_map(batched_ds, bm_config)

        batched_ds.reset()
        if mode_key == 'TRAIN':
            if self._cname2cid is not None:
                logger.warning('cname2cid already set, it will be overridden')
            self._cname2cid = getattr(sc, 'cname2cid', None)

        # 3, Build a reader
        maxit = -1 if self._maxiter <= 0 else self._maxiter
        iter_batches = self._iter_batches

        def _reader():
            for _batch in iter_batches(batched_ds, maxit):
                yield _batch

        if hasattr(sc, 'get_imid2path'):
            _reader.imid2path = sc.get_imid2path()

        return _reader

    def train(self):
        """Build reader for training"""
        return self._make_reader('TRAIN')

    def val(self):
        """Build reader for validation"""
        return self._make_reader('VAL')

    def test(self):
        """Build reader for inference"""
        return self._make_reader('TEST')

    @classmethod
    def create(cls, mode, data_config,
               transform_config, max_iter=-1,
               my_source=None, ret_iter=True):
        """Create a specific reader.

        Args:
            mode (str): mode name under which both configs are registered.
            data_config: data configuration for `mode`.
            transform_config: transform configuration for `mode`.
            max_iter (int): forwarded as `maxiter` to the constructor.
            my_source: optional ready-made data source, forwarded to
                ``_make_reader``.
            ret_iter (bool): if True, return the reader function built for
                `mode`; otherwise return the reader object itself.

        Returns:
            The reader function for `mode`, or an instance of `cls`.
        """
        # Instantiate through `cls` so subclasses get their own type back.
        reader = cls({mode: data_config}, {mode: transform_config}, max_iter)
        if ret_iter:
            return reader._make_reader(mode, my_source)
        return reader