未验证 提交 83364301 编写于 作者: Z zhiboniu 提交者: GitHub

hrnet fix (#2920)

上级 03326eea
......@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 10
save_dir: output
snapshot_epoch: 10
weights: output/higherhrnet_hrnet_v1_512/model_final
weights: output/higherhrnet_hrnet_w32_512_swahr/model_final
epoch: 300
num_joints: &num_joints 17
flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
......
......@@ -2,7 +2,7 @@ use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 10
weights: output/hrnet_coco_256x192/50
weights: output/hrnet_coco_256x192/model_final
epoch: 210
num_joints: &num_joints 17
pixel_std: &pixel_std 200
......
......@@ -26,7 +26,7 @@ logger = setup_logger(__name__)
__all__ = ['get_categories']
def get_categories(metric_type, arch, anno_file=None):
def get_categories(metric_type, anno_file=None, arch=None):
"""
Get class id to category id map and category id
to category name map from annotation file.
......@@ -83,6 +83,9 @@ def get_categories(metric_type, arch, anno_file=None):
elif metric_type.lower() == 'widerface':
return _widerface_category()
elif metric_type.lower() == 'keypointtopdowncocoeval':
return (None, {'id': 'keypoint'})
else:
raise ValueError("unknown metric type {}".format(metric_type))
......
......@@ -39,7 +39,7 @@ registered_ops = []
__all__ = [
'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
'TopDownAffine', 'ToHeatmapsTopDown'
'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
]
......@@ -564,6 +564,38 @@ class TopDownAffine(object):
return records
@register_keypointop
class TopDownEvalAffine(object):
    """Warp an evaluation image to the training resolution with an affine map.

    The transform is built for the whole image with no rotation: the centre
    is half of the reversed ``im_shape`` and the scale is the reversed
    ``im_shape`` itself.

    Args:
        trainsize (list): [w, h], the standard size used to train

    When called:
        records (dict): must provide the 'image' and 'im_shape' entries

    Returns:
        records (dict): the same dict, with 'image' replaced by the warped
            image of size (trainsize[0], trainsize[1])
    """

    def __init__(self, trainsize):
        self.trainsize = trainsize

    def __call__(self, records):
        source_image = records['image']
        # im_shape is reversed before use; presumably this turns (h, w) into
        # (w, h) so it lines up with trainsize -- confirm against the reader.
        full_extent = records['im_shape'][::-1]
        # Centre = half the extent, scale = the extent, rotation = 0.
        affine_matrix = get_affine_transform(full_extent / 2., full_extent, 0,
                                             self.trainsize)
        target_size = (int(self.trainsize[0]), int(self.trainsize[1]))
        records['image'] = cv2.warpAffine(
            source_image, affine_matrix, target_size, flags=cv2.INTER_LINEAR)
        return records
@register_keypointop
class ToHeatmapsTopDown(object):
"""to generate the gaussin heatmaps of keypoint for heatmap loss
......
......@@ -49,7 +49,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
anno_file = dataset_cfg.get_anno()
clsid2catid, catid2name = get_categories(metric, arch, anno_file)
clsid2catid, catid2name = get_categories(metric, anno_file, arch)
label_list = [str(cat) for cat in catid2name.values()]
......
......@@ -392,6 +392,7 @@ class Trainer(object):
batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num']
start = 0
for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)]
......
......@@ -56,13 +56,11 @@ class KeyPointTopDownCOCOEval(object):
self.idx = 0
def update(self, inputs, outputs):
kpt_coord = outputs['kpt_coord']
kpt_score = outputs['kpt_score']
kpts, _ = outputs['keypoint'][0]
num_images = inputs['image'].shape[0]
self.results['all_preds'][self.idx:self.idx + num_images, :, 0:
2] = kpt_coord[:, :, 0:2]
self.results['all_preds'][self.idx:self.idx + num_images, :, 2:
3] = kpt_score
3] = kpts[:, :, 0:3]
self.results['all_boxes'][self.idx:self.idx + num_images, 0:2] = inputs[
'center'].numpy()[:, 0:2]
self.results['all_boxes'][self.idx:self.idx + num_images, 2:4] = inputs[
......@@ -115,7 +113,7 @@ class KeyPointTopDownCOCOEval(object):
result = [{
'image_id': img_kpts[k]['image'],
'category_id': cat_id,
'keypoints': list(_key_points[k]),
'keypoints': _key_points[k].tolist(),
'score': img_kpts[k]['score'],
'center': list(img_kpts[k]['center']),
'scale': list(img_kpts[k]['scale'])
......
......@@ -39,7 +39,7 @@ class TopDownHRNet(BaseArch):
loss='KeyPointMSELoss',
post_process='HRNetPostProcess',
flip_perm=None,
flip=False,
flip=True,
shift_heatmap=True):
"""
HRNet network, see https://arxiv.org/abs/1902.09212
......@@ -57,6 +57,7 @@ class TopDownHRNet(BaseArch):
self.flip = flip
self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
self.shift_heatmap = shift_heatmap
self.deploy = False
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -71,31 +72,37 @@ class TopDownHRNet(BaseArch):
if self.training:
return self.loss(hrnet_outputs, self.inputs)
elif self.deploy:
return hrnet_outputs
else:
if self.flip:
self.inputs['image'] = self.inputs['image'].flip([3])
feats = backbone(inputs)
output_flipped = self.final_conv(feats)
feats = self.backbone(self.inputs)
output_flipped = self.final_conv(feats[0])
output_flipped = self.flip_back(output_flipped.numpy(),
flip_perm)
self.flip_perm)
output_flipped = paddle.to_tensor(output_flipped.copy())
if self.shift_heatmap:
output_flipped[:, :, :, 1:] = output_flipped.clone(
)[:, :, :, 0:-1]
output = (output + output_flipped) * 0.5
preds, maxvals = self.post_process(hrnet_outputs, self.inputs)
return preds, maxvals
hrnet_outputs = (hrnet_outputs + output_flipped) * 0.5
imshape = (self.inputs['im_shape'].numpy()
)[:, ::-1] if 'im_shape' in self.inputs else None
center = self.inputs['center'].numpy(
) if 'center' in self.inputs else np.round(imshape / 2.)
scale = self.inputs['scale'].numpy(
) if 'scale' in self.inputs else imshape / 200.
outputs = self.post_process(hrnet_outputs, center, scale)
return outputs
def get_loss(self):
    """Run the shared forward pass and return its result unchanged.

    NOTE(review): ``_forward`` is defined outside this view; it presumably
    yields the loss while the model is in training mode -- confirm.
    """
    forward_result = self._forward()
    return forward_result
def get_pred(self):
preds, maxvals = self._forward()
output = {'kpt_coord': preds, 'kpt_score': maxvals}
return output
res_lst = self._forward()
outputs = {'keypoint': res_lst}
return outputs
class HRNetPostProcess(object):
def flip_back(self, output_flipped, matched_parts):
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
......@@ -109,6 +116,8 @@ class HRNetPostProcess(object):
return output_flipped
class HRNetPostProcess(object):
def get_max_preds(self, heatmaps):
'''get predictions from score maps
......@@ -156,7 +165,7 @@ class HRNetPostProcess(object):
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
"""
coords, maxvals = self.get_max_preds(heatmaps)
......@@ -184,8 +193,11 @@ class HRNetPostProcess(object):
return preds, maxvals
def __call__(self, output, inputs):
preds, maxvals = self.get_final_preds(
output.numpy(), inputs['center'].numpy(), inputs['scale'].numpy())
return preds, maxvals
def __call__(self, output, center, scale):
    """Turn raw heatmaps into packed keypoint results.

    Args:
        output: heatmap tensor; only its ``.numpy()`` view is used here
        center: forwarded unchanged to ``get_final_preds``
        scale: forwarded unchanged to ``get_final_preds``

    Returns:
        list: ``[[keypoints, mean_score]]`` where ``keypoints`` is the
            coordinates and their confidences joined along the last axis,
            and ``mean_score`` is the confidence averaged over axis 1
            (presumably the joints axis -- confirm with get_final_preds).
    """
    coords, confidences = self.get_final_preds(output.numpy(), center, scale)
    # Append each keypoint's confidence after its coordinates.
    keypoints = np.concatenate((coords, confidences), axis=-1)
    mean_score = np.mean(confidences, axis=1)
    return [[keypoints, mean_score]]
......@@ -246,7 +246,6 @@ def draw_pose(image, results, visual_thread=0.6, save_name='pose.jpg'):
skeletons = np.array([item['keypoints'] for item in results]).reshape(-1,
17, 3)
scores = [item['score'] for item in results]
img = np.array(image).astype('float32')
canvas = img.copy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册