未验证 提交 ab51c5c9 编写于 作者: Z zhiboniu 提交者: GitHub

Develop add hrnet mpii infer (#3747)

* add python train time eval

* add mpii infer support
上级 3c761465
...@@ -240,6 +240,8 @@ def draw_pose(imgfile, ...@@ -240,6 +240,8 @@ def draw_pose(imgfile,
raise e raise e
skeletons, scores = results['keypoint'] skeletons, scores = results['keypoint']
kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = skeletons.shape[1] kpt_nums = skeletons.shape[1]
if kpt_nums == 17: #plot coco keypoint if kpt_nums == 17: #plot coco keypoint
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
......
...@@ -83,7 +83,8 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -83,7 +83,8 @@ def get_categories(metric_type, anno_file=None, arch=None):
elif metric_type.lower() == 'widerface': elif metric_type.lower() == 'widerface':
return _widerface_category() return _widerface_category()
elif metric_type.lower() == 'keypointtopdowncocoeval': elif metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower(
) == 'keypointtopdownmpiieval':
return (None, {'id': 'keypoint'}) return (None, {'id': 'keypoint'})
elif metric_type.lower() in ['mot', 'motdet', 'reid']: elif metric_type.lower() in ['mot', 'motdet', 'reid']:
......
...@@ -292,11 +292,7 @@ class Trainer(object): ...@@ -292,11 +292,7 @@ class Trainer(object):
def train(self, validate=False): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
Init_mark = False
# if validation in training is enabled, metrics should be re-init
if validate:
self._init_metrics(validate=validate)
self._reset_metrics()
model = self.model model = self.model
if self.cfg.get('fleet', False): if self.cfg.get('fleet', False):
...@@ -394,6 +390,12 @@ class Trainer(object): ...@@ -394,6 +390,12 @@ class Trainer(object):
self._eval_dataset, self._eval_dataset,
self.cfg.worker_num, self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler) batch_sampler=self._eval_batch_sampler)
# if validation in training is enabled, metrics should be re-init
# Init_mark makes sure this code will only execute once
if validate and Init_mark == False:
Init_mark = True
self._init_metrics(validate=validate)
self._reset_metrics()
with paddle.no_grad(): with paddle.no_grad():
self.status['save_best_model'] = True self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader) self._eval_with_loader(self._eval_loader)
...@@ -558,9 +560,7 @@ class Trainer(object): ...@@ -558,9 +560,7 @@ class Trainer(object):
shape=[None, 3, 192, 64], name='crops') shape=[None, 3, 192, 64], name='crops')
}) })
static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program # NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec # input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec( pruned_input_spec = self._prune_input_spec(
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import json import json
from collections import defaultdict from collections import defaultdict, OrderedDict
import numpy as np import numpy as np
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
......
...@@ -218,23 +218,35 @@ def draw_segm(image, ...@@ -218,23 +218,35 @@ def draw_segm(image,
return Image.fromarray(img_array.astype('uint8')) return Image.fromarray(img_array.astype('uint8'))
def map_coco_to_personlab(keypoints): def draw_pose(image,
permute = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3] results,
return keypoints[:, permute, :] visual_thread=0.6,
save_name='pose.jpg',
save_dir='output',
def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'): returnimg=False,
ids=None):
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib import matplotlib
plt.switch_backend('agg') plt.switch_backend('agg')
except Exception as e: except Exception as e:
logger.error('Matplotlib not found, plaese install matplotlib.' logger.error('Matplotlib not found, please install matplotlib.'
'for example: `pip install matplotlib`.') 'for example: `pip install matplotlib`.')
raise e raise e
EDGES = [(0, 14), (0, 13), (0, 4), (0, 1), (14, 16), (13, 15), (4, 10),
(1, 7), (10, 11), (7, 8), (11, 12), (8, 9), (4, 5), (1, 2), (5, 6), skeletons = np.array([item['keypoints'] for item in results])
(2, 3)] kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = int(skeletons.shape[1] / 3)
skeletons = skeletons.reshape(-1, kpt_nums, 3)
if kpt_nums == 17: #plot coco keypoint
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
(13, 15), (14, 16), (11, 12)]
else: #plot mpii keypoint
EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
(8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12),
(8, 13)]
NUM_EDGES = len(EDGES) NUM_EDGES = len(EDGES)
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
...@@ -242,22 +254,36 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'): ...@@ -242,22 +254,36 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
cmap = matplotlib.cm.get_cmap('hsv') cmap = matplotlib.cm.get_cmap('hsv')
plt.figure() plt.figure()
skeletons = np.array([item['keypoints'] for item in results]).reshape(-1,
17, 3)
img = np.array(image).astype('float32') img = np.array(image).astype('float32')
canvas = img.copy()
for i in range(17): color_set = results['colors'] if 'colors' in results else None
rgba = np.array(cmap(1 - i / 17. - 1. / 34))
rgba[0:3] *= 255 if 'bbox' in results and ids is None:
bboxs = results['bbox']
for j, rect in enumerate(bboxs):
xmin, ymin, xmax, ymax = rect
color = colors[0] if color_set is None else colors[color_set[j] %
len(colors)]
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
canvas = img.copy()
for i in range(kpt_nums):
for j in range(len(skeletons)): for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thread: if skeletons[j][i, 2] < visual_thread:
continue continue
if ids is None:
color = colors[i] if color_set is None else colors[color_set[j]
%
len(colors)]
else:
color = get_color(ids[j])
cv2.circle( cv2.circle(
canvas, canvas,
tuple(skeletons[j][i, 0:2].astype('int32')), tuple(skeletons[j][i, 0:2].astype('int32')),
2, 2,
colors[i], color,
thickness=-1) thickness=-1)
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0) to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
...@@ -265,7 +291,6 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'): ...@@ -265,7 +291,6 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
stickwidth = 2 stickwidth = 2
skeletons = map_coco_to_personlab(skeletons)
for i in range(NUM_EDGES): for i in range(NUM_EDGES):
for j in range(len(skeletons)): for j in range(len(skeletons)):
edge = EDGES[i] edge = EDGES[i]
...@@ -283,7 +308,13 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'): ...@@ -283,7 +308,13 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
polygon = cv2.ellipse2Poly((int(mY), int(mX)), polygon = cv2.ellipse2Poly((int(mY), int(mX)),
(int(length / 2), stickwidth), (int(length / 2), stickwidth),
int(angle), 0, 360, 1) int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) if ids is None:
color = colors[i] if color_set is None else colors[color_set[j]
%
len(colors)]
else:
color = get_color(ids[j])
cv2.fillConvexPoly(cur_canvas, polygon, color)
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
image = Image.fromarray(canvas.astype('uint8')) image = Image.fromarray(canvas.astype('uint8'))
plt.close() plt.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册