# coding=utf-8
from __future__ import absolute_import

import ast
import argparse
import os
from functools import partial

import yaml
import paddle
import numpy as np
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving

from .processor import load_label_info, postprocess, base64_to_cv2
from .data_feed import reader


@moduleinfo(
    name="ssd_vgg16_300_coco2017",
    version="1.1.0",
    type="cv/object_detection",
    summary="SSD with backbone VGG16, trained with dataset COCO.",
    author="paddlepaddle",
    author_email="paddle-dev@baidu.com")
class SSDVGG16:
    """PaddleHub module: SSD object detector with a VGG16 backbone, trained on COCO.

    NOTE(review): ``self.directory`` and ``self.name`` are read below but never
    assigned in this class; presumably they are injected by PaddleHub's
    ``moduleinfo`` / module machinery.
    """

    def __init__(self):
        """Locate the model files and label list, then build the predictors."""
        self.default_pretrained_model_path = os.path.join(
            self.directory, "ssd_vgg16_300_model", "model")
        self.label_names = load_label_info(
            os.path.join(self.directory, "label_file.txt"))
        # Filled lazily by _set_config from config.yml.
        self.model_config = None
        self._set_config()

    @staticmethod
    def _gpu_available():
        """Return True if CUDA_VISIBLE_DEVICES is set and its first character parses as an int.

        An unset, empty or non-numeric value (e.g. "-1") means no GPU
        predictor is built.
        """
        try:
            int(os.environ["CUDA_VISIBLE_DEVICES"][0])
        except (KeyError, IndexError, ValueError):
            return False
        return True

    def _set_config(self):
        """
        Create the inference predictors and load the model configuration.

        Always sets ``self.cpu_predictor``. Sets ``self.gpu_predictor`` to a
        GPU predictor when CUDA_VISIBLE_DEVICES advertises a device, otherwise
        to ``None``. Reads ``config.yml`` from the module directory if no
        config is loaded yet and exposes its ``MultiBoxHead`` and
        ``SSDOutputDecoder`` sections as attributes.
        """
        model = self.default_pretrained_model_path + '.pdmodel'
        params = self.default_pretrained_model_path + '.pdiparams'

        cpu_config = Config(model, params)
        cpu_config.disable_glog_info()
        cpu_config.disable_gpu()
        # NOTE(review): IR optimisation is disabled for the CPU predictor only;
        # the reason is not visible here.
        cpu_config.switch_ir_optim(False)
        self.cpu_predictor = create_predictor(cpu_config)

        # Always define the attribute so object_detection can report a clear
        # error (instead of an AttributeError) when no GPU predictor exists.
        self.gpu_predictor = None
        if self._gpu_available():
            gpu_config = Config(model, params)
            gpu_config.disable_glog_info()
            # device_id is hard-coded to 0.
            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
            self.gpu_predictor = create_predictor(gpu_config)

        # model config setting: read once from the module directory.
        if not self.model_config:
            config_path = os.path.join(self.directory, 'config.yml')
            with open(config_path, encoding='utf-8') as fp:
                self.model_config = yaml.load(fp.read(), Loader=yaml.FullLoader)

        self.multi_box_head_config = self.model_config['MultiBoxHead']
        self.output_decoder_config = self.model_config['SSDOutputDecoder']

    def object_detection(self,
                         paths=None,
                         images=None,
                         batch_size=1,
                         use_gpu=False,
                         output_dir='detection_result',
                         score_thresh=0.5,
                         visualization=True):
        """API of Object Detection.

        Args:
            paths (list[str]): The paths of images.
            images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
            batch_size (int): batch size.
            use_gpu (bool): Whether to use gpu.
            output_dir (str): The path to store output images.
            score_thresh (float): threshold for object detection.
            visualization (bool): Whether to save image or not.

        Returns:
            res (list[dict]): The result of coco2017 detection. keys include 'data', 'save_path', the corresponding value is:
                data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
                    left (float): The X coordinate of the upper left corner of the bounding box;
                    top (float): The Y coordinate of the upper left corner of the bounding box;
                    right (float): The X coordinate of the lower right corner of the bounding box;
                    bottom (float): The Y coordinate of the lower right corner of the bounding box;
                    label (str): The label of detection result;
                    confidence (float): The confidence of detection result.
                save_path (str, optional): The path to save output images.

        Raises:
            RuntimeError: if ``use_gpu`` is True but no GPU predictor was built
                (CUDA_VISIBLE_DEVICES was unset or invalid when the module
                was loaded).
        """
        # Choose the predictor up front so a missing GPU fails fast.
        predictor = self.cpu_predictor
        if use_gpu:
            predictor = getattr(self, 'gpu_predictor', None)
            if predictor is None:
                raise RuntimeError(
                    "use_gpu=True, but no GPU predictor is available. Set the "
                    "environment variable CUDA_VISIBLE_DEVICES to a cuda device "
                    "id (e.g. 0) before loading the module.")

        paths = paths if paths else list()
        data_reader = partial(reader, paths, images)
        batch_reader = paddle.batch(data_reader, batch_size=batch_size)

        res = []
        for iter_id, feed_data in enumerate(batch_reader()):
            # Element 0 of every sample is the network input. Stack only those
            # instead of converting the whole sample tuples (which breaks on
            # ragged samples with recent NumPy).
            batch_images = np.array([sample[0] for sample in feed_data])

            input_names = predictor.get_input_names()
            input_handle = predictor.get_input_handle(input_names[0])
            input_handle.copy_from_cpu(batch_images)

            predictor.run()
            output_names = predictor.get_output_names()
            output_handle = predictor.get_output_handle(output_names[0])

            # postprocess receives the output handle itself, not a copied
            # array; handle_id is the global index of this batch's first image.
            output = postprocess(paths=paths,
                                 images=images,
                                 data_out=output_handle,
                                 score_thresh=score_thresh,
                                 label_names=self.label_names,
                                 output_dir=output_dir,
                                 handle_id=iter_id * batch_size,
                                 visualization=visualization)
            res.extend(output)
        return res

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service.

        Args:
            images (list): encoded images, decoded with ``base64_to_cv2``
                (presumably base64 strings) before detection.
            **kwargs: forwarded unchanged to ``object_detection``.
        """
        images_decode = [base64_to_cv2(image) for image in images]
        results = self.object_detection(images=images_decode, **kwargs)
        return results

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command.

        Args:
            argvs (list[str]): command-line arguments to parse.

        Returns:
            The result of ``object_detection`` on ``--input_path``.
        """
        self.parser = argparse.ArgumentParser(
            description="Run the {} module.".format(self.name),
            prog='hub run {}'.format(self.name),
            usage='%(prog)s',
            add_help=True)
        self.arg_input_group = self.parser.add_argument_group(
            title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options",
            description=
            "Run configuration for controlling module behavior, not required.")
        self.add_module_config_arg()
        self.add_module_input_arg()
        args = self.parser.parse_args(argvs)
        # NOTE(review): --input_path is not marked required; if it is omitted,
        # [None] is passed as paths.
        results = self.object_detection(
            paths=[args.input_path],
            batch_size=args.batch_size,
            use_gpu=args.use_gpu,
            output_dir=args.output_dir,
            visualization=args.visualization,
            score_thresh=args.score_thresh)
        return results

    def add_module_config_arg(self):
        """
        Add the command config options (--use_gpu, --output_dir, --visualization).
        """
        # ast.literal_eval turns strings such as "True"/"False" into Python values.
        self.arg_config_group.add_argument(
            '--use_gpu',
            type=ast.literal_eval,
            default=False,
            help="whether use GPU or not")
        self.arg_config_group.add_argument(
            '--output_dir',
            type=str,
            default='detection_result',
            help="The directory to save output images.")
        self.arg_config_group.add_argument(
            '--visualization',
            type=ast.literal_eval,
            default=False,
            help="whether to save output as images.")

    def add_module_input_arg(self):
        """
        Add the command input options (--input_path, --batch_size, --score_thresh).
        """
        self.arg_input_group.add_argument(
            '--input_path', type=str, help="path to image.")
        self.arg_input_group.add_argument(
            '--batch_size',
            type=ast.literal_eval,
            default=1,
            help="batch size.")
        self.arg_input_group.add_argument(
            '--score_thresh',
            type=ast.literal_eval,
            default=0.5,
            help="threshold for object detecion.")