# test_ssd_vgg16_512_coco2017.py (last committed by wuzewu)
# coding=utf-8
import os
import unittest

import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub

# Directory holding the sample images used by these tests. The path is
# relative, so it is resolved against the current working directory.
image_dir = '../image_dataset/object_detection/'


class TestSSDVGG512(unittest.TestCase):
    """Tests for the ``ssd_vgg16_512_coco2017`` PaddleHub module."""

    @classmethod
    def setUpClass(cls):
        """Prepare the environment once before execution of all tests."""
        # Load the module once and share it across all tests in the class.
        cls.ssd = hub.Module(name="ssd_vgg16_512_coco2017")

    @classmethod
    def tearDownClass(cls):
        """Clean up the environment after the execution of all tests."""
        cls.ssd = None

    def setUp(self):
        """Give each test a fresh, empty fluid Program to build into."""
        self.test_prog = fluid.Program()

    def tearDown(self):
        """Release the per-test Program."""
        self.test_prog = None

    def test_context(self):
        """Build the module via ``context`` and check its input/output keys."""
        with fluid.program_guard(self.test_prog):
            get_prediction = True
            inputs, outputs, _program = self.ssd.context(
                pretrained=True, trainable=True, get_prediction=get_prediction)
            self.assertIn('image', inputs)
            self.assertIn('im_size', inputs)
            # The output key to look for depends on get_prediction.
            expected_output = 'bbox_out' if get_prediction else 'body_features'
            self.assertIn(expected_output, outputs)

    def test_object_detection(self):
        """Run object_detection on paths, in-memory images, and both together."""
        with fluid.program_guard(self.test_prog):
            zebra_path = os.path.join(image_dir, 'zebra.jpg')
            zebra = cv2.imread(zebra_path)
            # cv2.imread returns None (instead of raising) when the file is
            # missing or unreadable; fail with a clear message rather than an
            # AttributeError on .astype().
            self.assertIsNotNone(zebra, 'failed to read image: %s' % zebra_path)
            zebra = zebra.astype('float32')
            zebras = [zebra, zebra]

            # only paths
            print(
                self.ssd.object_detection(
                    paths=[os.path.join(image_dir, 'cat.jpg')]))
            # only images
            print(self.ssd.object_detection(images=zebras))
            # paths and images
            print(
                self.ssd.object_detection(
                    paths=[
                        os.path.join(image_dir, 'cat.jpg'),
                        os.path.join(image_dir, 'dog.jpg'),
                        os.path.join(image_dir, 'giraffe.jpg')
                    ],
                    images=zebras,
                    batch_size=2,
                    score_thresh=0.5))


if __name__ == "__main__":
    import sys

    # Run the detection test first, then the context test, with verbose output.
    suite = unittest.TestSuite()
    suite.addTest(TestSSDVGG512('test_object_detection'))
    suite.addTest(TestSSDVGG512('test_context'))
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # Propagate failures to the process exit status so scripts/CI notice them.
    sys.exit(0 if result.wasSuccessful() else 1)