From c089b67b80ee2a54a7b50c7d14dfdeb2cdf5269d Mon Sep 17 00:00:00 2001 From: walloollaw <37680514+walloollaw@users.noreply.github.com> Date: Wed, 31 Jul 2019 12:09:31 +0800 Subject: [PATCH] support custom reader in ppdet.data (#2965) --- ppdet/data/data_feed.py | 51 ++++++----- ppdet/data/reader.py | 27 ++++-- ppdet/data/source/__init__.py | 5 +- ppdet/data/source/iterator_source.py | 103 +++++++++++++++++++++++ ppdet/data/tests/run_all_tests.py | 2 + ppdet/data/tests/set_env.py | 3 + ppdet/data/tests/test_iterator_source.py | 60 +++++++++++++ ppdet/data/tests/test_reader.py | 27 ++++++ 8 files changed, 249 insertions(+), 29 deletions(-) create mode 100644 ppdet/data/source/iterator_source.py create mode 100644 ppdet/data/tests/test_iterator_source.py diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index e15b9da7c..dd85e83fe 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -42,14 +42,7 @@ __all__ = [ ] -def create_reader(feed, max_iter=0, args_path=None): - """ - Return iterable data reader. - - Args: - max_iter (int): number of iterations. - """ - +def _prepare_data_config(feed, args_path): # if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download dataset_home = args_path if args_path else feed.dataset.dataset_dir @@ -72,22 +65,36 @@ def create_reader(feed, max_iter=0, args_path=None): if getattr(feed, 'use_process', None) is not None: use_process = feed.use_process - mode = feed.mode data_config = { - mode: { - 'ANNO_FILE': feed.dataset.annotation, - 'IMAGE_DIR': feed.dataset.image_dir, - 'USE_DEFAULT_LABEL': feed.dataset.use_default_label, - 'IS_SHUFFLE': feed.shuffle, - 'SAMPLES': feed.samples, - 'WITH_BACKGROUND': feed.with_background, - 'MIXUP_EPOCH': mixup_epoch, - 'TYPE': type(feed.dataset).__source__ - } + 'ANNO_FILE': feed.dataset.annotation, + 'IMAGE_DIR': feed.dataset.image_dir, + 'USE_DEFAULT_LABEL': feed.dataset.use_default_label, + 'IS_SHUFFLE': feed.shuffle, + 'SAMPLES': feed.samples, + 'WITH_BACKGROUND': feed.with_background, + 'MIXUP_EPOCH': mixup_epoch, + 'TYPE': type(feed.dataset).__source__ } if len(getattr(feed.dataset, 'images', [])) > 0: - data_config[mode]['IMAGES'] = feed.dataset.images + data_config['IMAGES'] = feed.dataset.images + + return data_config + + +def create_reader(feed, max_iter=0, args_path=None, my_source=None): + """ + Return iterable data reader. + + Args: + max_iter (int): number of iterations. + my_source (callable): callable function to create a source iterator + which is used to provide source data in 'ppdet.data.reader' + """ + + # if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory + # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download + data_config = _prepare_data_config(feed, args_path) transform_config = { 'WORKER_CONF': { @@ -130,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None): ops.append(op_dict) transform_config['OPS'] = ops - reader = Reader(data_config, {mode: transform_config}, max_iter) - return reader._make_reader(mode) + return Reader.create(feed.mode, data_config, + transform_config, max_iter, my_source) # XXX batch transforms are only stubs for now, actually handled by `post_map` diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 497724ba0..5370bb9e4 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -40,14 +40,17 @@ class Reader(object): self._cname2cid = None assert isinstance(self._maxiter, Integral), "maxiter should be int" - def _make_reader(self, mode): + def _make_reader(self, mode, my_source=None): """Build reader for training or validation""" - file_conf = self._data_cf[mode] + if my_source is None: + file_conf = self._data_cf[mode] - # 1, Build data source + # 1, Build data source - sc_conf = {'data_cf': file_conf, 'cname2cid': self._cname2cid} - sc = build_source(sc_conf) + sc_conf = {'data_cf': file_conf, 'cname2cid': self._cname2cid} + sc = build_source(sc_conf) + else: + sc = my_source # 2, Buid a transformed dataset ops = self._trans_conf[mode]['OPS'] @@ -87,7 +90,7 @@ class Reader(object): if mode.lower() == 'train': if self._cname2cid is not None: logger.warn('cname2cid already set, it will be overridden') - self._cname2cid = sc.cname2cid + self._cname2cid = getattr(sc, 'cname2cid', None) # 3, Build a reader maxit = -1 if self._maxiter <= 0 else self._maxiter @@ -120,3 +123,15 @@ class Reader(object): def test(self): """Build reader for inference""" return self._make_reader('TEST') + + @classmethod + def create(cls, mode, data_config, + transform_config, max_iter=-1, + my_source=None, ret_iter=True): + """ create a specific reader """ + reader = Reader({mode: data_config}, + {mode: transform_config}, max_iter) + if ret_iter: + return reader._make_reader(mode, my_source) + else: + return reader diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index 8e4910941..ca0d5c833 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -20,6 +20,7 @@ import copy from .roidb_source import RoiDbSource from .simple_source import SimpleSource +from .iterator_source import IteratorSource def build_source(config): @@ -40,11 +41,13 @@ def build_source(config): } """ if 'data_cf' in config: - data_cf = {k.lower(): v for k, v in config['data_cf'].items()} + data_cf = config['data_cf'] data_cf['cname2cid'] = config['cname2cid'] else: data_cf = config + data_cf = {k.lower(): v for k, v in data_cf.items()} + args = copy.deepcopy(data_cf) # defaut type is 'RoiDbSource' source_type = 'RoiDbSource' diff --git a/ppdet/data/source/iterator_source.py b/ppdet/data/source/iterator_source.py new file mode 100644 index 000000000..2785d4843 --- /dev/null +++ b/ppdet/data/source/iterator_source.py @@ -0,0 +1,103 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import copy + +import logging +logger = logging.getLogger(__name__) + +from ..dataset import Dataset + + +class IteratorSource(Dataset): + """ + Load data samples from iterator in stream mode + + Args: + iter_maker (callable): callable function to generate a iter + samples (int): number of samples to load, -1 means all + """ + + def __init__(self, + iter_maker, + samples=-1, + **kwargs): + super(IteratorSource, self).__init__() + self._epoch = -1 + + self._iter_maker = iter_maker + self._data_iter = None + self._pos = -1 + self._drained = False + self._samples = samples + self._sample_num = -1 + + def next(self): + if self._epoch < 0: + self.reset() + + if self._data_iter is not None: + try: + sample = next(self._data_iter) + self._pos += 1 + ret = sample + except StopIteration as e: + if self._sample_num <= 0: + self._sample_num = self._pos + elif self._sample_num != self._pos: + logger.info('num of loaded samples is different ' + 'with previouse setting[prev:%d,now:%d]' % (self._sample_num, self._pos)) + self._sample_num = self._pos + + self._data_iter = None + self._drained = True + raise e + else: + raise StopIteration("no more data in " + str(self)) + + if self._samples > 0 and self._pos >= self._samples: + self._data_iter = None + self._drained = True + raise StopIteration("no more data in " + str(self)) + else: + return ret + + def reset(self): + if self._data_iter is None: + self._data_iter = self._iter_maker() + + if self._epoch < 0: + self._epoch = 0 + else: + self._epoch += 1 + + self._pos = 0 + self._drained = False + + def size(self): + return self._sample_num + + def drained(self): + assert self._epoch >= 0, "the first epoch has not started yet" + return self._pos >= self.size() + + def epoch_id(self): + return self._epoch + diff --git a/ppdet/data/tests/run_all_tests.py b/ppdet/data/tests/run_all_tests.py index cc5b7622f..220b9f4fb 100644 --- a/ppdet/data/tests/run_all_tests.py +++ b/ppdet/data/tests/run_all_tests.py @@ -7,6 +7,7 @@ import unittest import test_loader import test_operator import test_roidb_source +import test_iterator_source import test_transformer import test_reader @@ -17,6 +18,7 @@ if __name__ == '__main__': test_loader.TestLoader, test_operator.TestBase, test_roidb_source.TestRoiDbSource, + test_iterator_source.TestIteratorSource, test_transformer.TestTransformer, test_reader.TestReader, ] diff --git a/ppdet/data/tests/set_env.py b/ppdet/data/tests/set_env.py index 2ded06a5d..0f97e0164 100644 --- a/ppdet/data/tests/set_env.py +++ b/ppdet/data/tests/set_env.py @@ -3,6 +3,9 @@ import os import six import logging +import matplotlib +matplotlib.use('Agg', force=False) + prefix = os.path.dirname(os.path.abspath(__file__)) #coco data for testing diff --git a/ppdet/data/tests/test_iterator_source.py b/ppdet/data/tests/test_iterator_source.py new file mode 100644 index 000000000..1d8604177 --- /dev/null +++ b/ppdet/data/tests/test_iterator_source.py @@ -0,0 +1,60 @@ +import os +import time +import unittest +import sys +import logging + +import set_env +from ppdet.data.source import IteratorSource + + +def _generate_iter_maker(num=10): + def _reader(): + for i in range(num): + yield {'image': 'image_' + str(i), 'label': i} + + return _reader + +class TestIteratorSource(unittest.TestCase): + """Test cases for dataset.source.roidb_source + """ + + @classmethod + def setUpClass(cls): + """ setup + """ + pass + + @classmethod + def tearDownClass(cls): + """ tearDownClass """ + pass + + def test_basic(self): + """ test basic apis 'next/size/drained' + """ + iter_maker = _generate_iter_maker() + iter_source = IteratorSource(iter_maker) + for i, sample in enumerate(iter_source): + self.assertTrue('image' in sample) + self.assertGreater(len(sample['image']), 0) + self.assertTrue(iter_source.drained()) + self.assertEqual(i + 1, iter_source.size()) + + def test_reset(self): + """ test functions 'reset/epoch_id' + """ + iter_maker = _generate_iter_maker() + iter_source = IteratorSource(iter_maker) + + self.assertTrue(iter_source.next() is not None) + self.assertEqual(iter_source.epoch_id(), 0) + + iter_source.reset() + + self.assertEqual(iter_source.epoch_id(), 1) + self.assertTrue(iter_source.next() is not None) + + +if __name__ == '__main__': + unittest.main() diff --git a/ppdet/data/tests/test_reader.py b/ppdet/data/tests/test_reader.py index bd6db801d..48fbc9d2c 100644 --- a/ppdet/data/tests/test_reader.py +++ b/ppdet/data/tests/test_reader.py @@ -8,6 +8,8 @@ import yaml import set_env from ppdet.data.reader import Reader +from ppdet.data.source import build_source +from ppdet.data.source import IteratorSource class TestReader(unittest.TestCase): @@ -114,6 +116,31 @@ class TestReader(unittest.TestCase): self.assertEqual(out[0][5].shape[1], 1) self.assertGreaterEqual(ct, rcnn._maxiter) + def test_create(self): + """ Test create a reader using my source + """ + def _my_data_reader(): + mydata = build_source(self.rcnn_conf['DATA']['TRAIN']) + for i, sample in enumerate(mydata): + yield sample + + my_source = IteratorSource(_my_data_reader) + mode = 'TRAIN' + train_rd = Reader.create(mode, + self.rcnn_conf['DATA'][mode], + self.rcnn_conf['TRANSFORM'][mode], + max_iter=10, my_source=my_source) + + out = None + for sample in train_rd(): + out = sample + self.assertTrue(sample is not None) + self.assertEqual(out[0][0].shape[0], 3) + self.assertEqual(out[0][1].shape[0], 3) + self.assertEqual(out[0][3].shape[1], 4) + self.assertEqual(out[0][4].shape[1], 1) + self.assertEqual(out[0][5].shape[1], 1) + if __name__ == '__main__': unittest.main() -- GitLab