未验证 提交 83364301 编写于 作者: Z zhiboniu 提交者: GitHub

hrnet fix (#2920)

上级 03326eea
...@@ -2,7 +2,7 @@ use_gpu: true ...@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 10 log_iter: 10
save_dir: output save_dir: output
snapshot_epoch: 10 snapshot_epoch: 10
weights: output/higherhrnet_hrnet_v1_512/model_final weights: output/higherhrnet_hrnet_w32_512_swahr/model_final
epoch: 300 epoch: 300
num_joints: &num_joints 17 num_joints: &num_joints 17
flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
......
...@@ -2,7 +2,7 @@ use_gpu: true ...@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 5 log_iter: 5
save_dir: output save_dir: output
snapshot_epoch: 10 snapshot_epoch: 10
weights: output/hrnet_coco_256x192/50 weights: output/hrnet_coco_256x192/model_final
epoch: 210 epoch: 210
num_joints: &num_joints 17 num_joints: &num_joints 17
pixel_std: &pixel_std 200 pixel_std: &pixel_std 200
......
...@@ -26,7 +26,7 @@ logger = setup_logger(__name__) ...@@ -26,7 +26,7 @@ logger = setup_logger(__name__)
__all__ = ['get_categories'] __all__ = ['get_categories']
def get_categories(metric_type, arch, anno_file=None): def get_categories(metric_type, anno_file=None, arch=None):
""" """
Get class id to category id map and category id Get class id to category id map and category id
to category name map from annotation file. to category name map from annotation file.
...@@ -83,6 +83,9 @@ def get_categories(metric_type, arch, anno_file=None): ...@@ -83,6 +83,9 @@ def get_categories(metric_type, arch, anno_file=None):
elif metric_type.lower() == 'widerface': elif metric_type.lower() == 'widerface':
return _widerface_category() return _widerface_category()
elif metric_type.lower() == 'keypointtopdowncocoeval':
return (None, {'id': 'keypoint'})
else: else:
raise ValueError("unknown metric type {}".format(metric_type)) raise ValueError("unknown metric type {}".format(metric_type))
......
...@@ -39,7 +39,7 @@ registered_ops = [] ...@@ -39,7 +39,7 @@ registered_ops = []
__all__ = [ __all__ = [
'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps', 'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform', 'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
'TopDownAffine', 'ToHeatmapsTopDown' 'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
] ]
...@@ -564,6 +564,38 @@ class TopDownAffine(object): ...@@ -564,6 +564,38 @@ class TopDownAffine(object):
return records return records
@register_keypointop
class TopDownEvalAffine(object):
    """Warp the input image to the training resolution without rotation.

    The affine matrix is requested from ``get_affine_transform`` using the
    midpoint of the reversed ``im_shape`` as center, the reversed
    ``im_shape`` itself as scale, and a rotation of 0.

    Args:
        trainsize (list): [w, h], the standard size used to train
    """

    def __init__(self, trainsize):
        self.trainsize = trainsize

    def __call__(self, records):
        """Replace ``records['image']`` with its warped version.

        Args:
            records (dict): must provide the keys 'image' and 'im_shape'
        Returns:
            records (dict): the same dict, with 'image' overwritten by the
                warped image
        """
        source_img = records['image']
        # Reversed im_shape; presumably (h, w) -> (w, h) to match the
        # [w, h] layout of trainsize -- TODO confirm against the reader.
        # NOTE(review): the division below assumes an array-like that
        # supports element-wise arithmetic (e.g. numpy) -- confirm.
        full_extent = records['im_shape'][::-1]
        affine = get_affine_transform(full_extent / 2., full_extent, 0,
                                      self.trainsize)
        target_size = (int(self.trainsize[0]), int(self.trainsize[1]))
        records['image'] = cv2.warpAffine(
            source_img, affine, target_size, flags=cv2.INTER_LINEAR)
        return records
@register_keypointop @register_keypointop
class ToHeatmapsTopDown(object): class ToHeatmapsTopDown(object):
"""to generate the gaussin heatmaps of keypoint for heatmap loss """to generate the gaussin heatmaps of keypoint for heatmap loss
......
...@@ -49,7 +49,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): ...@@ -49,7 +49,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
anno_file = dataset_cfg.get_anno() anno_file = dataset_cfg.get_anno()
clsid2catid, catid2name = get_categories(metric, arch, anno_file) clsid2catid, catid2name = get_categories(metric, anno_file, arch)
label_list = [str(cat) for cat in catid2name.values()] label_list = [str(cat) for cat in catid2name.values()]
......
...@@ -392,6 +392,7 @@ class Trainer(object): ...@@ -392,6 +392,7 @@ class Trainer(object):
batch_res = get_infer_results(outs, clsid2catid) batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num'] bbox_num = outs['bbox_num']
start = 0 start = 0
for i, im_id in enumerate(outs['im_id']): for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)] image_path = imid2path[int(im_id)]
......
...@@ -56,13 +56,11 @@ class KeyPointTopDownCOCOEval(object): ...@@ -56,13 +56,11 @@ class KeyPointTopDownCOCOEval(object):
self.idx = 0 self.idx = 0
def update(self, inputs, outputs): def update(self, inputs, outputs):
kpt_coord = outputs['kpt_coord'] kpts, _ = outputs['keypoint'][0]
kpt_score = outputs['kpt_score']
num_images = inputs['image'].shape[0] num_images = inputs['image'].shape[0]
self.results['all_preds'][self.idx:self.idx + num_images, :, 0: self.results['all_preds'][self.idx:self.idx + num_images, :, 0:
2] = kpt_coord[:, :, 0:2] 3] = kpts[:, :, 0:3]
self.results['all_preds'][self.idx:self.idx + num_images, :, 2:
3] = kpt_score
self.results['all_boxes'][self.idx:self.idx + num_images, 0:2] = inputs[ self.results['all_boxes'][self.idx:self.idx + num_images, 0:2] = inputs[
'center'].numpy()[:, 0:2] 'center'].numpy()[:, 0:2]
self.results['all_boxes'][self.idx:self.idx + num_images, 2:4] = inputs[ self.results['all_boxes'][self.idx:self.idx + num_images, 2:4] = inputs[
...@@ -115,7 +113,7 @@ class KeyPointTopDownCOCOEval(object): ...@@ -115,7 +113,7 @@ class KeyPointTopDownCOCOEval(object):
result = [{ result = [{
'image_id': img_kpts[k]['image'], 'image_id': img_kpts[k]['image'],
'category_id': cat_id, 'category_id': cat_id,
'keypoints': list(_key_points[k]), 'keypoints': _key_points[k].tolist(),
'score': img_kpts[k]['score'], 'score': img_kpts[k]['score'],
'center': list(img_kpts[k]['center']), 'center': list(img_kpts[k]['center']),
'scale': list(img_kpts[k]['scale']) 'scale': list(img_kpts[k]['scale'])
......
...@@ -39,7 +39,7 @@ class TopDownHRNet(BaseArch): ...@@ -39,7 +39,7 @@ class TopDownHRNet(BaseArch):
loss='KeyPointMSELoss', loss='KeyPointMSELoss',
post_process='HRNetPostProcess', post_process='HRNetPostProcess',
flip_perm=None, flip_perm=None,
flip=False, flip=True,
shift_heatmap=True): shift_heatmap=True):
""" """
HRNnet network, see https://arxiv.org/abs/1902.09212 HRNnet network, see https://arxiv.org/abs/1902.09212
...@@ -57,6 +57,7 @@ class TopDownHRNet(BaseArch): ...@@ -57,6 +57,7 @@ class TopDownHRNet(BaseArch):
self.flip = flip self.flip = flip
self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True) self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
self.shift_heatmap = shift_heatmap self.shift_heatmap = shift_heatmap
self.deploy = False
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
...@@ -71,31 +72,37 @@ class TopDownHRNet(BaseArch): ...@@ -71,31 +72,37 @@ class TopDownHRNet(BaseArch):
if self.training: if self.training:
return self.loss(hrnet_outputs, self.inputs) return self.loss(hrnet_outputs, self.inputs)
elif self.deploy:
return hrnet_outputs
else: else:
if self.flip: if self.flip:
self.inputs['image'] = self.inputs['image'].flip([3]) self.inputs['image'] = self.inputs['image'].flip([3])
feats = backbone(inputs) feats = self.backbone(self.inputs)
output_flipped = self.final_conv(feats) output_flipped = self.final_conv(feats[0])
output_flipped = self.flip_back(output_flipped.numpy(), output_flipped = self.flip_back(output_flipped.numpy(),
flip_perm) self.flip_perm)
output_flipped = paddle.to_tensor(output_flipped.copy()) output_flipped = paddle.to_tensor(output_flipped.copy())
if self.shift_heatmap: if self.shift_heatmap:
output_flipped[:, :, :, 1:] = output_flipped.clone( output_flipped[:, :, :, 1:] = output_flipped.clone(
)[:, :, :, 0:-1] )[:, :, :, 0:-1]
output = (output + output_flipped) * 0.5 hrnet_outputs = (hrnet_outputs + output_flipped) * 0.5
preds, maxvals = self.post_process(hrnet_outputs, self.inputs) imshape = (self.inputs['im_shape'].numpy()
return preds, maxvals )[:, ::-1] if 'im_shape' in self.inputs else None
center = self.inputs['center'].numpy(
) if 'center' in self.inputs else np.round(imshape / 2.)
scale = self.inputs['scale'].numpy(
) if 'scale' in self.inputs else imshape / 200.
outputs = self.post_process(hrnet_outputs, center, scale)
return outputs
def get_loss(self): def get_loss(self):
return self._forward() return self._forward()
def get_pred(self): def get_pred(self):
preds, maxvals = self._forward() res_lst = self._forward()
output = {'kpt_coord': preds, 'kpt_score': maxvals} outputs = {'keypoint': res_lst}
return output return outputs
class HRNetPostProcess(object):
def flip_back(self, output_flipped, matched_parts): def flip_back(self, output_flipped, matched_parts):
assert output_flipped.ndim == 4,\ assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]' 'output_flipped should be [batch_size, num_joints, height, width]'
...@@ -109,6 +116,8 @@ class HRNetPostProcess(object): ...@@ -109,6 +116,8 @@ class HRNetPostProcess(object):
return output_flipped return output_flipped
class HRNetPostProcess(object):
def get_max_preds(self, heatmaps): def get_max_preds(self, heatmaps):
'''get predictions from score maps '''get predictions from score maps
...@@ -156,7 +165,7 @@ class HRNetPostProcess(object): ...@@ -156,7 +165,7 @@ class HRNetPostProcess(object):
Returns: Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
""" """
coords, maxvals = self.get_max_preds(heatmaps) coords, maxvals = self.get_max_preds(heatmaps)
...@@ -184,8 +193,11 @@ class HRNetPostProcess(object): ...@@ -184,8 +193,11 @@ class HRNetPostProcess(object):
return preds, maxvals return preds, maxvals
def __call__(self, output, inputs): def __call__(self, output, center, scale):
preds, maxvals = self.get_final_preds( preds, maxvals = self.get_final_preds(output.numpy(), center, scale)
output.numpy(), inputs['center'].numpy(), inputs['scale'].numpy()) outputs = [[
np.concatenate(
return preds, maxvals (preds, maxvals), axis=-1), np.mean(
maxvals, axis=1)
]]
return outputs
...@@ -246,7 +246,6 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'): ...@@ -246,7 +246,6 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
skeletons = np.array([item['keypoints'] for item in results]).reshape(-1, skeletons = np.array([item['keypoints'] for item in results]).reshape(-1,
17, 3) 17, 3)
scores = [item['score'] for item in results]
img = np.array(image).astype('float32') img = np.array(image).astype('float32')
canvas = img.copy() canvas = img.copy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册