未验证 提交 a2f3fd43 编写于 作者: Z zhiboniu 提交者: GitHub

pose bottomup higherhrnet: others (#2639)

* Add BottomUp HigherHRNet for pose.
* reverse xtcocotools to pycocotools
上级 22c6e80a
......@@ -26,7 +26,7 @@ logger = setup_logger(__name__)
__all__ = ['get_categories']
def get_categories(metric_type, anno_file=None):
def get_categories(metric_type, arch, anno_file=None):
"""
Get class id to category id map and category id
to category name map from annotation file.
......@@ -36,6 +36,9 @@ def get_categories(metric_type, anno_file=None):
and 'widerface'.
anno_file (str): annotation file path
"""
if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'})
if metric_type.lower() == 'coco':
if anno_file and os.path.isfile(anno_file):
# lazy import pycocotools here
......@@ -46,7 +49,6 @@ def get_categories(metric_type, anno_file=None):
clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
catid2name = {cat['id']: cat['name'] for cat in cats}
return clsid2catid, catid2name
# anno file not exist, load default categories of COCO17
......@@ -81,9 +83,6 @@ def get_categories(metric_type, anno_file=None):
elif metric_type.lower() == 'widerface':
return _widerface_category()
elif metric_type.lower().startswith('keypoint'):
return (None, {'id': 'keypoint'})
else:
raise ValueError("unknown metric type {}".format(metric_type))
......
......@@ -17,8 +17,8 @@ import cv2
import numpy as np
import copy
# TODO: unify xtococotools and pycocotools
import xtcocotools
from xtcocotools.coco import COCO
import pycocotools
from pycocotools.coco import COCO
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
......@@ -152,6 +152,8 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
self.dataset_name = 'coco'
cat_ids = self.coco.getCatIds()
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
print(f'=> num_images: {self.num_images}')
@staticmethod
......@@ -235,16 +237,16 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
for obj in anno:
if 'segmentation' in obj:
if obj['iscrowd']:
rle = xtcocotools.mask.frPyObjects(obj['segmentation'],
rle = pycocotools.mask.frPyObjects(obj['segmentation'],
img_info['height'],
img_info['width'])
m += xtcocotools.mask.decode(rle)
m += pycocotools.mask.decode(rle)
elif obj['num_keypoints'] == 0:
rles = xtcocotools.mask.frPyObjects(obj['segmentation'],
rles = pycocotools.mask.frPyObjects(obj['segmentation'],
img_info['height'],
img_info['width'])
for rle in rles:
m += xtcocotools.mask.decode(rle)
m += pycocotools.mask.decode(rle)
return m < 0.5
......
......@@ -37,15 +37,18 @@ TRT_MIN_SUBGRAPH = {
'TTFNet': 3,
'FCOS': 16,
'SOLOv2': 60,
'HigherHrnet': 40,
}
KEYPOINT_ARCH = ['HigherHrnet', 'Hrnet']
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
preprocess_list = []
anno_file = dataset_cfg.get_anno()
clsid2catid, catid2name = get_categories(metric, anno_file)
clsid2catid, catid2name = get_categories(metric, arch, anno_file)
label_list = [str(cat) for cat in catid2name.values()]
......@@ -95,10 +98,13 @@ def _dump_infer_config(config, path, image_shape, model):
os._exit(0)
if 'Mask' in infer_arch:
infer_cfg['mask'] = True
label_arch = 'detection_arch'
if infer_arch in KEYPOINT_ARCH:
label_arch = 'keypoint_arch'
infer_cfg['Preprocess'], infer_cfg[
'label_list'], image_shape = _parse_reader(
config['TestReader'], config['TestDataset'], config['metric'],
infer_cfg['arch'], image_shape)
label_arch, image_shape)
yaml.dump(infer_cfg, open(path, 'w'))
logger.info("Export inference config file to {}".format(os.path.join(path)))
......
......@@ -351,7 +351,8 @@ class Trainer(object):
self._reset_metrics()
def evaluate(self):
self._eval_with_loader(self.loader)
with paddle.no_grad():
self._eval_with_loader(self.loader)
def predict(self,
images,
......@@ -376,7 +377,8 @@ class Trainer(object):
for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data[key]
for key, value in outs.items():
outs[key] = value.numpy()
if hasattr(value, 'numpy'):
outs[key] = value.numpy()
batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num']
......@@ -393,10 +395,12 @@ class Trainer(object):
if 'mask' in batch_res else None
segm_res = batch_res['segm'][start:end] \
if 'segm' in batch_res else None
keypoint_res = batch_res['keypoint'][start:end] \
if 'keypoint' in batch_res else None
image = visualize_results(image, bbox_res, mask_res, segm_res,
int(outs['im_id']), catid2name,
draw_threshold)
image = visualize_results(
image, bbox_res, mask_res, segm_res, keypoint_res,
int(outs['im_id']), catid2name, draw_threshold)
self.status['result_image'] = np.array(image.copy())
if self._compose_callback:
self._compose_callback.on_step_end(self.status)
......@@ -407,7 +411,13 @@ class Trainer(object):
image.save(save_name, quality=95)
if save_txt:
save_path = os.path.splitext(save_name)[0] + '.txt'
save_result(save_path, bbox_res, catid2name, draw_threshold)
results = {}
results["im_id"] = im_id
if bbox_res:
results["bbox_res"] = bbox_res
if keypoint_res:
results["keypoint_res"] = keypoint_res
save_result(save_path, results, catid2name, draw_threshold)
start = end
def _get_save_image_name(self, output_dir, image_path):
......@@ -435,6 +445,7 @@ class Trainer(object):
image_shape = [3, -1, -1]
self.model.eval()
if hasattr(self.model, 'deploy'): self.model.deploy = True
# Save infer cfg
_dump_infer_config(self.cfg,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from . import metrics
from .metrics import *
__all__ = metrics.__all__
......@@ -21,7 +21,7 @@ import sys
import numpy as np
import itertools
from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res, get_keypoint_res
from ppdet.metrics.map_utils import draw_pr_curve
from ppdet.utils.logger import setup_logger
......@@ -60,6 +60,10 @@ def get_infer_results(outs, catid, bias=0):
if 'segm' in outs:
infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid)
if 'keypoint' in outs:
infer_res['keypoint'] = get_keypoint_res(outs, im_id)
outs['bbox_num'] = [len(infer_res['keypoint'])]
return infer_res
......@@ -68,20 +72,30 @@ def cocoapi_eval(jsonfile,
coco_gt=None,
anno_file=None,
max_dets=(100, 300, 1000),
classwise=False):
classwise=False,
sigmas=None,
use_area=True):
"""
Args:
jsonfile (str): Evaluation json file, eg: bbox.json, mask.json.
style (str): COCOeval style, can be `bbox` , `segm` and `proposal`.
style (str): COCOeval style, can be `bbox` , `segm` , `proposal`, `keypoints` and `keypoints_crowd`.
coco_gt (str): Whether to load COCOAPI through anno_file,
eg: coco_gt = COCO(anno_file)
anno_file (str): COCO annotations file.
max_dets (tuple): COCO evaluation maxDets.
classwise (bool): Whether per-category AP and draw P-R Curve or not.
sigmas (nparray): keypoint labelling sigmas.
use_area (bool): If gt annotations (eg. CrowdPose, AIC)
do not have 'area', please set use_area=False.
"""
assert coco_gt != None or anno_file != None
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
if style == 'keypoints_crowd':
#please install xtcocotools==1.6
from xtcocotools.coco import COCO
from xtcocotools.cocoeval import COCOeval
else:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
if coco_gt == None:
coco_gt = COCO(anno_file)
......@@ -91,6 +105,9 @@ def cocoapi_eval(jsonfile,
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.params.useCats = 0
coco_eval.params.maxDets = list(max_dets)
elif style == 'keypoints_crowd':
coco_eval = COCOeval(coco_gt, coco_dt, style, sigmas, use_area)
coco_gt.anno_file.append("")
else:
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.evaluate()
......@@ -134,7 +151,7 @@ def cocoapi_eval(jsonfile,
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(
*[results_flatten[i::num_columns] for i in range(num_columns)])
* [results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
......
......@@ -149,3 +149,27 @@ def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
}
segm_res.append(coco_res)
return segm_res
def get_keypoint_res(results, im_id, category_id=1):
    """
    Convert keypoint predictions into a flat list of COCO-style result dicts.

    Args:
        results (dict): must contain the key 'keypoint'. `results['keypoint'][i]`
            unpacks into `(kpts, scores)` for image `i`; `kpts` and `scores`
            are iterated together, one entry per detected pose. Each pose is
            flattened and read as `[x, y, v, x, y, v, ...]`.
        im_id (np.ndarray): image ids; the first axis indexes images and
            each `im_id[i]` must hold exactly one element.
        category_id (int): category id written into every annotation.
            Defaults to 1, the value that used to be hard-coded here.

    Returns:
        list[dict]: one dict per pose with the keys 'image_id',
            'category_id', 'keypoints', 'score', 'area' and 'bbox', where
            'bbox' is `[x_min, y_min, width, height]` of the keypoints and
            'area' is `width * height`.
    """
    anns = []
    preds = results['keypoint']
    for idx in range(im_id.shape[0]):
        image_id = im_id[idx].item()
        kpts, scores = preds[idx]
        for kpt, score in zip(kpts, scores):
            kpt = kpt.flatten()
            # every third value starting at 0 / 1 is an x / y coordinate
            xs = kpt[0::3]
            ys = kpt[1::3]
            x0, x1 = np.min(xs).item(), np.max(xs).item()
            y0, y1 = np.min(ys).item(), np.max(ys).item()
            width, height = x1 - x0, y1 - y0
            anns.append({
                'image_id': image_id,
                'category_id': category_id,
                'keypoints': kpt.tolist(),
                'score': float(score),
                'area': width * height,
                'bbox': [x0, y0, width, height],
            })
    return anns
......@@ -34,6 +34,14 @@ __all__ = [
'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results'
]
# Per-keypoint sigmas handed to the keypoint evaluation as `sigmas`.
# COCO_SIGMAS holds 17 values and is used for the default 'keypoints' style;
# the listed values are divided by 10.
COCO_SIGMAS = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87,
.89, .89
]) / 10.0
# CROWD_SIGMAS holds 14 values and is used when the IoU type is
# 'keypoints_crowd'; the listed values are divided by 10 as well.
CROWD_SIGMAS = np.array(
[.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .79,
.79]) / 10.0
class Metric(paddle.metric.Metric):
def name(self):
......@@ -70,11 +78,12 @@ class COCOMetric(Metric):
# TODO: bias should be unified
self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.iou_type = kwargs.get('IouType', 'bbox')
self.reset()
def reset(self):
# only bbox and mask evaluation support currently
self.results = {'bbox': [], 'mask': [], 'segm': []}
self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
self.eval_results = {}
def update(self, inputs, outputs):
......@@ -95,6 +104,8 @@ class COCOMetric(Metric):
'mask'] if 'mask' in infer_results else []
self.results['segm'] += infer_results[
'segm'] if 'segm' in infer_results else []
self.results['keypoint'] += infer_results[
'keypoint'] if 'keypoint' in infer_results else []
def accumulate(self):
if len(self.results['bbox']) > 0:
......@@ -157,6 +168,35 @@ class COCOMetric(Metric):
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
if len(self.results['keypoint']) > 0:
output = "keypoint.json"
if self.output_eval:
output = os.path.join(self.output_eval, output)
with open(output, 'w') as f:
json.dump(self.results['keypoint'], f)
logger.info('The keypoint result is saved to keypoint.json.')
if self.save_prediction_only:
logger.info('The keypoint result is saved to {} and do not '
'evaluate the mAP.'.format(output))
else:
style = 'keypoints'
use_area = True
sigmas = COCO_SIGMAS
if self.iou_type == 'keypoints_crowd':
style = 'keypoints_crowd'
use_area = False
sigmas = CROWD_SIGMAS
keypoint_stats = cocoapi_eval(
output,
style,
anno_file=self.anno_file,
classwise=self.classwise,
sigmas=sigmas,
use_area=use_area)
self.eval_results['keypoint'] = keypoint_stats
sys.stdout.flush()
def log(self):
pass
......
......@@ -38,9 +38,11 @@ class HigherHrnet(BaseArch):
hrhrnet_head='HigherHrnetHead',
post_process='HrHrnetPostProcess',
eval_flip=True,
flip_perm=None):
flip_perm=None,
max_num_people=30):
"""
HigherHrnet network, see https://arxiv.org/abs/
HigherHrnet network, see https://arxiv.org/abs/1908.10357;
HigherHrnet+swahr, see https://arxiv.org/abs/2012.15175
Args:
backbone (nn.Layer): backbone instance
......@@ -54,6 +56,9 @@ class HigherHrnet(BaseArch):
self.flip = eval_flip
self.flip_perm = paddle.to_tensor(flip_perm)
self.deploy = False
self.interpolate = L.Upsample(2, mode='bilinear')
self.pool = L.MaxPool(5, 1, 2)
self.max_num_people = max_num_people
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -71,7 +76,6 @@ class HigherHrnet(BaseArch):
}
def _forward(self):
batchsize = self.inputs['image'].shape[0]
if self.flip and not self.training and not self.deploy:
self.inputs['image'] = paddle.concat(
(self.inputs['image'], paddle.flip(self.inputs['image'], [3])))
......@@ -81,9 +85,7 @@ class HigherHrnet(BaseArch):
return self.hrhrnet_head(body_feats, self.inputs)
else:
outputs = self.hrhrnet_head(body_feats)
if self.deploy:
return outputs, [1]
if self.flip:
if self.flip and not self.deploy:
outputs = [paddle.split(o, 2) for o in outputs]
output_rflip = [
paddle.flip(paddle.gather(o[1], self.flip_perm, 1), [3])
......@@ -93,37 +95,69 @@ class HigherHrnet(BaseArch):
heatmap = (output1[0] + output_rflip[0]) / 2.
tagmaps = [output1[1], output_rflip[1]]
outputs = [heatmap] + tagmaps
outputs = self.get_topk(outputs)
res_lst = []
bboxnums = []
for idx in range(batchsize):
item = [o[idx:(idx + 1)] for o in outputs]
if self.deploy:
return outputs
h = self.inputs['im_shape'][idx, 0].numpy().item()
w = self.inputs['im_shape'][idx, 1].numpy().item()
kpts, scores = self.post_process(item, h, w)
res_lst.append([kpts, scores])
bboxnums.append(1)
res_lst = []
h = self.inputs['im_shape'][0, 0].numpy().item()
w = self.inputs['im_shape'][0, 1].numpy().item()
kpts, scores = self.post_process(*outputs, h, w)
res_lst.append([kpts, scores])
return res_lst, bboxnums
return res_lst
def get_loss(self):
    """Delegate to `_forward` and return its result unchanged."""
    return self._forward()
def get_pred(self):
outputs = {}
res_lst, bboxnums = self._forward()
res_lst = self._forward()
outputs['keypoint'] = res_lst
outputs['bbox_num'] = bboxnums
return outputs
def get_topk(self, outputs):
    """
    Upsample the head outputs and select the strongest heatmap peaks.

    Args:
        outputs (list): head outputs. Element 0 is used as the heatmap and
            element 1 as a tag map; when exactly three elements are given,
            element 2 is stacked with element 1 along a new axis 4.

    Returns:
        list: `[heatmap, tagmap, heat_k, inds_k]`, where `heat_k` and
            `inds_k` are the values and flat indices of the
            `self.max_num_people` largest local maxima per joint.
    """
    # resize to image size
    upsampled = [self.interpolate(feat) for feat in outputs]
    heatmap = upsampled[0]

    first_tag = upsampled[1].unsqueeze(4)
    if len(upsampled) == 3:
        tagmap = paddle.concat(
            (first_tag, upsampled[2].unsqueeze(4)), axis=4)
    else:
        tagmap = first_tag

    num_joints = self.hrhrnet_head.num_joints
    pooled = self.pool(heatmap)

    # topk: zero out everything that is not equal to its max-pooled value,
    # then flatten the spatial dims (the batch dimension is fixed to 1)
    peak_only = heatmap * (heatmap == pooled)
    peak_only = peak_only.reshape([1, num_joints, -1])
    heat_k, inds_k = peak_only.topk(self.max_num_people, axis=2)

    return [heatmap, tagmap, heat_k, inds_k]
@register
@serializable
class HrHrnetPostProcess(object):
'''
HrHrnet postprocess contain:
1) get topk keypoints in the output heatmap
2) sample the tagmap's value corresponding to each of the topk coordinate
3) match different joints to combine to some people with Hungary algorithm
4) adjust the coordinate by +-0.25 to decrease error std
5) salvage missing joints by check positivity of heatmap - tagdiff_norm
Args:
max_num_people (int): max number of people support in postprocess
heat_thresh (float): value of topk below this threshhold will be ignored
tag_thresh (float): coord's value sampled in tagmap below this threshold belong to same people for init
inputs(list[heatmap]): the output list of modle, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk
original_height, original_width (float): the original image size
'''
def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.):
self.interpolate = L.Upsample(2, mode='bilinear')
self.pool = L.MaxPool(5, 1, 2)
self.max_num_people = max_num_people
self.heat_thresh = heat_thresh
self.tag_thresh = tag_thresh
......@@ -140,25 +174,11 @@ class HrHrnetPostProcess(object):
-0.25)
return offset_y + 0.5, offset_x + 0.5
def __call__(self, inputs, original_height, original_width):
# resize to image size
inputs = [self.interpolate(x) for x in inputs]
# aggregate
heatmap = inputs[0]
if len(inputs) == 3:
tagmap = paddle.concat(
(inputs[1].unsqueeze(4), inputs[2].unsqueeze(4)), axis=4)
else:
tagmap = inputs[1].unsqueeze(4)
def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height,
original_width):
N, J, H, W = heatmap.shape
assert N == 1, "only support batch size 1"
# topk
maximum = self.pool(heatmap)
maxmap = heatmap * (heatmap == maximum)
maxmap = maxmap.reshape([N, J, -1])
heat_k, inds_k = maxmap.topk(self.max_num_people, axis=2)
heatmap = heatmap[0].cpu().detach().numpy()
tagmap = tagmap[0].cpu().detach().numpy()
heats = heat_k[0].cpu().detach().numpy()
......@@ -240,18 +260,10 @@ class HrHrnetPostProcess(object):
mean_score = pose_scores.mean(axis=1)
pose_kpts[valid, 2] = pose_scores[valid]
# TODO can we remove the outermost loop altogether
# salvage missing joints
if True:
for pid, coords in enumerate(pose_coords):
# vj = np.nonzero(valid[pid])[0]
# vyx = coords[valid[pid]].astype(np.int32)
# tag_mean = tagmap[vj, vyx[:, 0], vyx[:, 1]].mean(axis=0)
tag_mean = np.array(pose_tags[pid]).mean(
axis=0) #TODO: replace tagmap sample by history record
tag_mean = np.array(pose_tags[pid]).mean(axis=0)
norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5
score = heatmap - np.round(norm) # (J, H, W)
flat_score = score.reshape(J, -1)
......
......@@ -207,7 +207,6 @@ class OptimizerBuilder():
clip_norm=self.clip_grad_by_norm)
else:
grad_clip = None
if self.regularizer:
reg_type = self.regularizer['type'] + 'Decay'
reg_factor = self.regularizer['factor']
......
......@@ -89,7 +89,7 @@ DATASETS = {
'roadsign_coco': ([(
'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_coco.tar',
'49ce5a9b5ad0d6266163cd01de4b018e', ), ], ['annotations', 'images']),
'objects365': (),
'objects365': ()
}
DOWNLOAD_RETRY_LIMIT = 3
......
......@@ -20,6 +20,9 @@ from __future__ import unicode_literals
import numpy as np
from PIL import Image, ImageDraw
import cv2
import os
import math
from .colormap import colormap
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
......@@ -31,6 +34,7 @@ def visualize_results(image,
bbox_res,
mask_res,
segm_res,
keypoint_res,
im_id,
catid2name,
threshold=0.5):
......@@ -43,6 +47,8 @@ def visualize_results(image,
image = draw_mask(image, im_id, mask_res, threshold)
if segm_res is not None:
image = draw_segm(image, im_id, catid2name, segm_res, threshold)
if keypoint_res is not None:
image = draw_pose(image, keypoint_res, threshold)
return image
......@@ -124,21 +130,32 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold):
return image
def save_result(save_path, bbox_res, catid2name, threshold):
def save_result(save_path, results, catid2name, threshold):
"""
save result as txt
"""
img_id = int(results["im_id"])
with open(save_path, 'w') as f:
for dt in bbox_res:
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
if score < threshold:
continue
# each bbox result as a line
# for rbox: classname score x1 y1 x2 y2 x3 y3 x4 y4
# for bbox: classname score x1 y1 w h
bbox_pred = '{} {} '.format(catid2name[catid], score) + ' '.join(
[str(e) for e in bbox])
f.write(bbox_pred + '\n')
if "bbox_res" in results:
for dt in results["bbox_res"]:
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
if score < threshold:
continue
# each bbox result as a line
# for rbox: classname score x1 y1 x2 y2 x3 y3 x4 y4
# for bbox: classname score x1 y1 w h
bbox_pred = '{} {} '.format(catid2name[catid],
score) + ' '.join(
[str(e) for e in bbox])
f.write(bbox_pred + '\n')
elif "keypoint_res" in results:
for dt in results["keypoint_res"]:
kpts = dt['keypoints']
scores = dt['score']
keypoint_pred = [img_id, scores, kpts]
print(keypoint_pred, file=f)
else:
print("No valid results found, skip txt save")
def draw_segm(image,
......@@ -200,3 +217,77 @@ def draw_segm(image,
lineType=cv2.LINE_AA)
return Image.fromarray(img_array.astype('uint8'))
def map_coco_to_personlab(keypoints):
    """
    Reorder the joint axis (axis 1) of `keypoints` with a fixed permutation.

    Args:
        keypoints: 3-D array indexed as `[pose, joint, value]` with 17 joints.

    Returns:
        The same data with the joints rearranged; position `k` of the result
        holds the input joint `joint_order[k]`.
    """
    joint_order = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3]
    reordered = keypoints[:, joint_order, :]
    return reordered
def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
    """
    Draw keypoints and skeleton limbs for every pose in `results`.

    Drawing is done with OpenCV only. Earlier code also imported matplotlib
    to compute a colormap value that was never used; that call
    (`matplotlib.cm.get_cmap`) no longer exists in recent Matplotlib
    releases, so the dead dependency has been dropped.

    Args:
        image: input image; converted with `np.array(image)`.
        results (list[dict]): one dict per pose. `item['keypoints']` must
            reshape to (17, 3); each row is read as (x, y, score).
        visual_thread (float): keypoints whose score is below this value are
            not drawn, and a limb is skipped if either endpoint is below it.
        save_name (str): unused; kept so existing callers keep working.

    Returns:
        PIL.Image.Image: the image with keypoints and limbs drawn on it.
    """
    # limb endpoints, expressed in the joint order produced by
    # map_coco_to_personlab
    EDGES = [(0, 14), (0, 13), (0, 4), (0, 1), (14, 16), (13, 15), (4, 10),
             (1, 7), (10, 11), (7, 8), (11, 12), (8, 9), (4, 5), (1, 2),
             (5, 6), (2, 3)]
    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
              [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
              [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
              [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
              [255, 0, 170], [255, 0, 85]]
    stickwidth = 2

    skeletons = np.array([item['keypoints'] for item in results]).reshape(
        -1, 17, 3)
    # astype returns a fresh array, so drawing never touches the input image
    canvas = np.array(image).astype('float32')

    # keypoints: one filled circle per confident joint, coloured per joint
    for i in range(17):
        for skeleton in skeletons:
            if skeleton[i, 2] < visual_thread:
                continue
            # plain Python ints are accepted by every OpenCV binding version
            center = (int(skeleton[i, 0]), int(skeleton[i, 1]))
            cv2.circle(canvas, center, 2, colors[i], thickness=-1)

    # limbs: one blended ellipse per edge whose two endpoints are confident
    skeletons = map_coco_to_personlab(skeletons)
    for edge, color in zip(EDGES, colors):
        start, end = edge
        for skeleton in skeletons:
            if (skeleton[start, 2] < visual_thread or
                    skeleton[end, 2] < visual_thread):
                continue

            cur_canvas = canvas.copy()
            xs = [skeleton[start, 0], skeleton[end, 0]]
            ys = [skeleton[start, 1], skeleton[end, 1]]
            mid_x = np.mean(xs)
            mid_y = np.mean(ys)
            length = ((ys[0] - ys[1])**2 + (xs[0] - xs[1])**2)**0.5
            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
            polygon = cv2.ellipse2Poly((int(mid_x), int(mid_y)),
                                       (int(length / 2), stickwidth),
                                       int(angle), 0, 360, 1)
            cv2.fillConvexPoly(cur_canvas, polygon, color)
            # blend so overlapping limbs stay visible
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

    return Image.fromarray(canvas.astype('uint8'))
......@@ -7,5 +7,5 @@ shapely
scipy
terminaltables
pycocotools
xtcocotools==1.6
#xtcocotools==1.6 #only for crowdpose
setuptools>=42.0.0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册