提交 e2759c78 编写于 作者: D dongdaxiang

add new version of rcnn detection example

上级 c8883c10
background
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import sys
from paddle_serving_app.reader.pddet import Detection
from paddle_serving_app.reader import File2Image, Sequential, Normalize, Resize, Transpose, Div, BGR2RGB, RCNNPostprocess
import numpy as np
preprocess = Sequential([
File2Image(), BGR2RGB(), Div(255.0),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
Resize(640, 640), Transpose((2, 0, 1))
])
postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client()
client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9393'])
for i in range(100):
im = preprocess(sys.argv[2])
fetch_map = client.predict(
feed={
"image": im,
"im_info": np.array(list(im.shape[1:]) + [1.0]),
"im_shape": np.array(list(im.shape[1:]) + [1.0])
},
fetch=["multiclass_nms"])
fetch_map["image"] = sys.argv[2]
postprocess(fetch_map)
......@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, RCNNPostprocess
......@@ -21,10 +21,14 @@ def transpose(img, transpose_target):
return img
def normalize(img, mean, std):
def normalize(img, mean, std, channel_first):
# need to optimize here
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
if channel_first:
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
else:
img_mean = np.array(mean).reshape((1, 1, 3))
img_std = np.array(std).reshape((1, 1, 3))
img -= img_mean
img /= img_std
return img
......@@ -45,12 +49,15 @@ def crop(img, target_size, center):
return img
def resize(img, target_size, interpolation):
def resize(img, target_size, max_size=2147483647, interpolation=None):
if isinstance(target_size, tuple):
resized_width = target_size[0]
resized_height = target_size[1]
resized_width = min(target_size[0], max_size)
resized_height = min(target_size[1], max_size)
else:
im_max_size = max(img.shape[0], img.shape[1])
percent = float(target_size) / min(img.shape[0], img.shape[1])
if np.round(percent * im_max_size) > max_size:
percent = float(max_size) / float(im_max_size)
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
......
......@@ -12,14 +12,170 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import os
import urllib
import numpy as np
import base64
import functional as F
from PIL import Image, ImageDraw
import json
_cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"}
class RCNNPostprocess(object):
def __init__(self, label_file, output_dir):
self.output_dir = output_dir
self.label_file = label_file
self.label_list = []
with open(label_file) as fin:
for line in fin:
self.label_list.append(line.strip())
self.clsid2catid = {i: i for i in range(len(self.label_list))}
self.catid2name = {i: name for i, name in enumerate(self.label_list)}
def _offset_to_lengths(self, lod):
offset = lod[0]
lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)]
return [lengths]
def _bbox2out(self, results, clsid2catid, is_bbox_normalized=False):
xywh_res = []
for t in results:
bboxes = t['bbox'][0]
lengths = t['bbox'][1][0]
if bboxes.shape == (1, 1) or bboxes is None:
continue
k = 0
for i in range(len(lengths)):
num = lengths[i]
for j in range(num):
dt = bboxes[k]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (clsid2catid[int(clsid)])
if is_bbox_normalized:
xmin, ymin, xmax, ymax = \
self.clip_bbox([xmin, ymin, xmax, ymax])
w = xmax - xmin
h = ymax - ymin
im_shape = t['im_shape'][0][i].tolist()
im_height, im_width = int(im_shape[0]), int(im_shape[1])
xmin *= im_width
ymin *= im_height
w *= im_width
h *= im_height
else:
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'category_id': catid,
'bbox': bbox,
'score': score
}
xywh_res.append(coco_res)
k += 1
return xywh_res
def _get_bbox_result(self, fetch_map, fetch_name, clsid2catid):
result = {}
is_bbox_normalized = False
output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']]
lengths = self._offset_to_lengths(lod)
np_data = np.array(output)
result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]])
bbox_results = self._bbox2out([result], clsid2catid, is_bbox_normalized)
return bbox_results
def color_map(self, num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = np.array(color_map).reshape(-1, 3)
return color_map
def draw_bbox(self, image, catid2name, bboxes, threshold, color_list):
"""
draw bbox on image
"""
draw = ImageDraw.Draw(image)
for dt in np.array(bboxes):
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
if score < threshold:
continue
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
color = tuple(color_list[catid])
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill=color)
# draw label
text = "{} {:.2f}".format(catid2name[catid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return image
def visualize(self, infer_img, bbox_results, catid2name, num_classes):
image = Image.open(infer_img).convert('RGB')
color_list = self.color_map(num_classes)
image = self.draw_bbox(image, self.catid2name, bbox_results, 0.5,
color_list)
image_path = os.path.split(infer_img)[-1]
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, image_path)
image.save(out_path, quality=95)
def __call__(self, image_with_bbox):
fetch_name = ""
for key in image_with_bbox:
if key == "image":
continue
if ".lod" in key:
continue
fetch_name = key
bbox_result = self._get_bbox_result(image_with_bbox, fetch_name,
self.clsid2catid)
if os.path.isdir(self.output_dir) is False:
os.mkdir(self.output_dir)
self.visualize(image_with_bbox["image"], bbox_result, self.catid2name,
len(self.label_list))
if os.path.isdir(self.output_dir) is False:
os.mkdir(self.output_dir)
bbox_file = os.path.join(self.output_dir, 'bbox.json')
with open(bbox_file, 'w') as f:
json.dump(bbox_result, f, indent=4)
def __repr__(self):
return self.__class__.__name__ + "label_file: {1}, output_dir: {2}".format(
self.label_file, self.output_dir)
class Sequential(object):
"""
Args:
......@@ -152,9 +308,10 @@ class Normalize(object):
"""
def __init__(self, mean, std):
def __init__(self, mean, std, channel_first=False):
self.mean = mean
self.std = std
self.channel_first = channel_first
def __call__(self, img):
"""
......@@ -164,7 +321,7 @@ class Normalize(object):
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(img, self.mean, self.std)
return F.normalize(img, self.mean, self.std, self.channel_first)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean,
......@@ -228,19 +385,21 @@ class Resize(object):
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
``None``
"""
def __init__(self, size, interpolation=None):
def __init__(self, size, max_size=2147483647, interpolation=None):
self.size = size
self.max_size = max_size
self.interpolation = interpolation
def __call__(self, img):
return F.resize(img, self.size, self.interpolation)
return F.resize(img, self.size, self.max_size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
self.size, _cv2_interpolation_to_str[self.interpolation])
return self.__class__.__name__ + '(size={0}, max_size={1}, interpolation={2})'.format(
self.size, self.max_size,
_cv2_interpolation_to_str[self.interpolation])
class Transpose(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册