提交 c089b67b 编写于 作者: W walloollaw 提交者: qingqing01

support custom reader in ppdet.data (#2965)

上级 3a6c1f95
......@@ -42,14 +42,7 @@ __all__ = [
]
def create_reader(feed, max_iter=0, args_path=None):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
"""
def _prepare_data_config(feed, args_path):
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
dataset_home = args_path if args_path else feed.dataset.dataset_dir
......@@ -72,9 +65,7 @@ def create_reader(feed, max_iter=0, args_path=None):
if getattr(feed, 'use_process', None) is not None:
use_process = feed.use_process
mode = feed.mode
data_config = {
mode: {
'ANNO_FILE': feed.dataset.annotation,
'IMAGE_DIR': feed.dataset.image_dir,
'USE_DEFAULT_LABEL': feed.dataset.use_default_label,
......@@ -84,10 +75,26 @@ def create_reader(feed, max_iter=0, args_path=None):
'MIXUP_EPOCH': mixup_epoch,
'TYPE': type(feed.dataset).__source__
}
}
if len(getattr(feed.dataset, 'images', [])) > 0:
data_config[mode]['IMAGES'] = feed.dataset.images
data_config['IMAGES'] = feed.dataset.images
return data_config
def create_reader(feed, max_iter=0, args_path=None, my_source=None):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
my_source (callable): callable function to create a source iterator
which is used to provide source data in 'ppdet.data.reader'
"""
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
data_config = _prepare_data_config(feed, args_path)
transform_config = {
'WORKER_CONF': {
......@@ -130,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None):
ops.append(op_dict)
transform_config['OPS'] = ops
reader = Reader(data_config, {mode: transform_config}, max_iter)
return reader._make_reader(mode)
return Reader.create(feed.mode, data_config,
transform_config, max_iter, my_source)
# XXX batch transforms are only stubs for now, actually handled by `post_map`
......
......@@ -40,14 +40,17 @@ class Reader(object):
self._cname2cid = None
assert isinstance(self._maxiter, Integral), "maxiter should be int"
def _make_reader(self, mode):
def _make_reader(self, mode, my_source=None):
"""Build reader for training or validation"""
if my_source is None:
file_conf = self._data_cf[mode]
# 1, Build data source
sc_conf = {'data_cf': file_conf, 'cname2cid': self._cname2cid}
sc = build_source(sc_conf)
else:
sc = my_source
# 2, Buid a transformed dataset
ops = self._trans_conf[mode]['OPS']
......@@ -87,7 +90,7 @@ class Reader(object):
if mode.lower() == 'train':
if self._cname2cid is not None:
logger.warn('cname2cid already set, it will be overridden')
self._cname2cid = sc.cname2cid
self._cname2cid = getattr(sc, 'cname2cid', None)
# 3, Build a reader
maxit = -1 if self._maxiter <= 0 else self._maxiter
......@@ -120,3 +123,15 @@ class Reader(object):
def test(self):
"""Build reader for inference"""
return self._make_reader('TEST')
@classmethod
def create(cls, mode, data_config,
transform_config, max_iter=-1,
my_source=None, ret_iter=True):
""" create a specific reader """
reader = Reader({mode: data_config},
{mode: transform_config}, max_iter)
if ret_iter:
return reader._make_reader(mode, my_source)
else:
return reader
......@@ -20,6 +20,7 @@ import copy
from .roidb_source import RoiDbSource
from .simple_source import SimpleSource
from .iterator_source import IteratorSource
def build_source(config):
......@@ -40,11 +41,13 @@ def build_source(config):
}
"""
if 'data_cf' in config:
data_cf = {k.lower(): v for k, v in config['data_cf'].items()}
data_cf = config['data_cf']
data_cf['cname2cid'] = config['cname2cid']
else:
data_cf = config
data_cf = {k.lower(): v for k, v in data_cf.items()}
args = copy.deepcopy(data_cf)
# defaut type is 'RoiDbSource'
source_type = 'RoiDbSource'
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import copy
import logging
logger = logging.getLogger(__name__)
from ..dataset import Dataset
class IteratorSource(Dataset):
"""
Load data samples from iterator in stream mode
Args:
iter_maker (callable): callable function to generate a iter
samples (int): number of samples to load, -1 means all
"""
def __init__(self,
iter_maker,
samples=-1,
**kwargs):
super(IteratorSource, self).__init__()
self._epoch = -1
self._iter_maker = iter_maker
self._data_iter = None
self._pos = -1
self._drained = False
self._samples = samples
self._sample_num = -1
def next(self):
if self._epoch < 0:
self.reset()
if self._data_iter is not None:
try:
sample = next(self._data_iter)
self._pos += 1
ret = sample
except StopIteration as e:
if self._sample_num <= 0:
self._sample_num = self._pos
elif self._sample_num != self._pos:
logger.info('num of loaded samples is different '
'with previouse setting[prev:%d,now:%d]' % (self._sample_num, self._pos))
self._sample_num = self._pos
self._data_iter = None
self._drained = True
raise e
else:
raise StopIteration("no more data in " + str(self))
if self._samples > 0 and self._pos >= self._samples:
self._data_iter = None
self._drained = True
raise StopIteration("no more data in " + str(self))
else:
return ret
def reset(self):
if self._data_iter is None:
self._data_iter = self._iter_maker()
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self._pos = 0
self._drained = False
def size(self):
return self._sample_num
def drained(self):
assert self._epoch >= 0, "the first epoch has not started yet"
return self._pos >= self.size()
def epoch_id(self):
return self._epoch
......@@ -7,6 +7,7 @@ import unittest
import test_loader
import test_operator
import test_roidb_source
import test_iterator_source
import test_transformer
import test_reader
......@@ -17,6 +18,7 @@ if __name__ == '__main__':
test_loader.TestLoader,
test_operator.TestBase,
test_roidb_source.TestRoiDbSource,
test_iterator_source.TestIteratorSource,
test_transformer.TestTransformer,
test_reader.TestReader,
]
......
......@@ -3,6 +3,9 @@ import os
import six
import logging
import matplotlib
matplotlib.use('Agg', force=False)
prefix = os.path.dirname(os.path.abspath(__file__))
#coco data for testing
......
import os
import time
import unittest
import sys
import logging
import set_env
from ppdet.data.source import IteratorSource
def _generate_iter_maker(num=10):
def _reader():
for i in range(num):
yield {'image': 'image_' + str(i), 'label': i}
return _reader
class TestIteratorSource(unittest.TestCase):
"""Test cases for dataset.source.roidb_source
"""
@classmethod
def setUpClass(cls):
""" setup
"""
pass
@classmethod
def tearDownClass(cls):
""" tearDownClass """
pass
def test_basic(self):
""" test basic apis 'next/size/drained'
"""
iter_maker = _generate_iter_maker()
iter_source = IteratorSource(iter_maker)
for i, sample in enumerate(iter_source):
self.assertTrue('image' in sample)
self.assertGreater(len(sample['image']), 0)
self.assertTrue(iter_source.drained())
self.assertEqual(i + 1, iter_source.size())
def test_reset(self):
""" test functions 'reset/epoch_id'
"""
iter_maker = _generate_iter_maker()
iter_source = IteratorSource(iter_maker)
self.assertTrue(iter_source.next() is not None)
self.assertEqual(iter_source.epoch_id(), 0)
iter_source.reset()
self.assertEqual(iter_source.epoch_id(), 1)
self.assertTrue(iter_source.next() is not None)
if __name__ == '__main__':
unittest.main()
......@@ -8,6 +8,8 @@ import yaml
import set_env
from ppdet.data.reader import Reader
from ppdet.data.source import build_source
from ppdet.data.source import IteratorSource
class TestReader(unittest.TestCase):
......@@ -114,6 +116,31 @@ class TestReader(unittest.TestCase):
self.assertEqual(out[0][5].shape[1], 1)
self.assertGreaterEqual(ct, rcnn._maxiter)
def test_create(self):
""" Test create a reader using my source
"""
def _my_data_reader():
mydata = build_source(self.rcnn_conf['DATA']['TRAIN'])
for i, sample in enumerate(mydata):
yield sample
my_source = IteratorSource(_my_data_reader)
mode = 'TRAIN'
train_rd = Reader.create(mode,
self.rcnn_conf['DATA'][mode],
self.rcnn_conf['TRANSFORM'][mode],
max_iter=10, my_source=my_source)
out = None
for sample in train_rd():
out = sample
self.assertTrue(sample is not None)
self.assertEqual(out[0][0].shape[0], 3)
self.assertEqual(out[0][1].shape[0], 3)
self.assertEqual(out[0][3].shape[1], 4)
self.assertEqual(out[0][4].shape[1], 1)
self.assertEqual(out[0][5].shape[1], 1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册