# test_iterator_source.py
import os
import time
import unittest
import sys
import logging

import set_env
from ppdet.data.source import IteratorSource


def _generate_iter_maker(num=10):
    def _reader():
        for i in range(num):
            yield {'image': 'image_' + str(i), 'label': i}

    return _reader

class TestIteratorSource(unittest.TestCase):
    """Test cases for ppdet.data.source.IteratorSource."""

    @classmethod
    def setUpClass(cls):
        """Class-level setup hook; no shared fixtures are needed."""
        pass

    @classmethod
    def tearDownClass(cls):
        """Class-level teardown hook; nothing to release."""
        pass

    def test_basic(self):
        """Test iteration plus the 'drained' and 'size' APIs."""
        iter_maker = _generate_iter_maker()
        iter_source = IteratorSource(iter_maker)

        # Count samples explicitly rather than reading the loop variable after
        # the loop: with an empty source that variable would never be bound
        # and the test would fail with a confusing NameError.
        count = 0
        for sample in iter_source:
            self.assertIn('image', sample)
            self.assertGreater(len(sample['image']), 0)
            count += 1

        self.assertGreater(count, 0, 'IteratorSource yielded no samples')
        self.assertTrue(iter_source.drained())
        self.assertEqual(count, iter_source.size())

    def test_reset(self):
        """Test the 'reset' and 'epoch_id' APIs."""
        iter_maker = _generate_iter_maker()
        iter_source = IteratorSource(iter_maker)

        self.assertIsNotNone(iter_source.next())
        self.assertEqual(iter_source.epoch_id(), 0)

        iter_source.reset()

        # This test expects reset() to bump the epoch counter and leave the
        # source readable again.
        self.assertEqual(iter_source.epoch_id(), 1)
        self.assertIsNotNone(iter_source.next())


if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module='__main__')