test_roidb_source.py 1.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
import os
import time
import unittest
import sys
import logging

import set_env
from data import build_source


class TestRoiDbSource(unittest.TestCase):
    """Test cases for dataset.source.roidb_source

    Each test builds a fresh source via ``build_source`` from the shared
    ``config`` prepared in ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Build the source config shared by all tests in this class.

        The COCO train annotation file and image directory are read from
        the local ``set_env`` module.
        """
        anno_path = set_env.coco_data['TRAIN']['ANNO_FILE']
        image_dir = set_env.coco_data['TRAIN']['IMAGE_DIR']
        cls.config = {
            'data_cf': {
                'anno_file': anno_path,
                'image_dir': image_dir,
                # NOTE(review): presumably caps how many records are
                # loaded, keeping the test fast — confirm in build_source.
                'samples': 100,
                # Images are loaded, so samples are expected to carry a
                # non-empty 'image' entry (checked in test_basic).
                'load_img': True
            },
            'cname2cid': None
        }

    def test_basic(self):
        """Iterate the whole source, then check 'drained' and 'size'.

        Every yielded sample must contain a non-empty ``'image'`` entry.
        Once iteration stops, the source must report itself drained, and
        the number of samples seen must equal ``size()``.
        """
        roi_source = build_source(self.config)
        # Count explicitly. With ``enumerate`` the loop index is never
        # bound when the source is empty, which turned an empty dataset
        # into a NameError instead of a readable test failure.
        count = 0
        for sample in roi_source:
            self.assertIn('image', sample)
            self.assertGreater(len(sample['image']), 0)
            count += 1
        self.assertGreater(count, 0, 'roidb source yielded no samples')
        self.assertTrue(roi_source.drained())
        self.assertEqual(count, roi_source.size())

    def test_reset(self):
        """Check that 'reset' bumps 'epoch_id' and the source reads again.

        A fresh source starts at epoch 0. After ``reset()`` it reports
        epoch 1 and still returns a sample from ``next()``.
        """
        roi_source = build_source(self.config)

        self.assertIsNotNone(roi_source.next())
        self.assertEqual(roi_source.epoch_id(), 0)

        roi_source.reset()

        self.assertEqual(roi_source.epoch_id(), 1)
        self.assertIsNotNone(roi_source.next())

if __name__ == "__main__":
    # Allow running this module directly (python test_roidb_source.py);
    # unittest discovers the TestCase subclasses defined above.
    unittest.main()