未验证 提交 912833f2 编写于 作者: Z zhiboniu 提交者: GitHub

change hrhrnet eval heat_thresh=0.1; fix crowdpose eval (#2945)

上级 1aa7e2e4
...@@ -60,7 +60,6 @@ class KeypointBottomUpBaseDataset(DetDataset): ...@@ -60,7 +60,6 @@ class KeypointBottomUpBaseDataset(DetDataset):
self.test_mode = test_mode self.test_mode = test_mode
self.ann_info['num_joints'] = num_joints self.ann_info['num_joints'] = num_joints
self.img_ids = [] self.img_ids = []
def __len__(self): def __len__(self):
......
...@@ -147,6 +147,7 @@ class Trainer(object): ...@@ -147,6 +147,7 @@ class Trainer(object):
eval_dataset.check_or_download_dataset() eval_dataset.check_or_download_dataset()
anno_file = eval_dataset.get_anno() anno_file = eval_dataset.get_anno()
IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
self._metrics = [ self._metrics = [
COCOMetric( COCOMetric(
anno_file=anno_file, anno_file=anno_file,
...@@ -154,6 +155,7 @@ class Trainer(object): ...@@ -154,6 +155,7 @@ class Trainer(object):
classwise=classwise, classwise=classwise,
output_eval=output_eval, output_eval=output_eval,
bias=bias, bias=bias,
IouType=IouType,
save_prediction_only=save_prediction_only) save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'VOC': elif self.cfg.metric == 'VOC':
......
...@@ -107,7 +107,6 @@ def cocoapi_eval(jsonfile, ...@@ -107,7 +107,6 @@ def cocoapi_eval(jsonfile,
coco_eval.params.maxDets = list(max_dets) coco_eval.params.maxDets = list(max_dets)
elif style == 'keypoints_crowd': elif style == 'keypoints_crowd':
coco_eval = COCOeval(coco_gt, coco_dt, style, sigmas, use_area) coco_eval = COCOeval(coco_gt, coco_dt, style, sigmas, use_area)
coco_gt.anno_file.append("")
else: else:
coco_eval = COCOeval(coco_gt, coco_dt, style) coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.evaluate() coco_eval.evaluate()
......
...@@ -52,7 +52,7 @@ class HigherHRNet(BaseArch): ...@@ -52,7 +52,7 @@ class HigherHRNet(BaseArch):
super(HigherHRNet, self).__init__() super(HigherHRNet, self).__init__()
self.backbone = backbone self.backbone = backbone
self.hrhrnet_head = hrhrnet_head self.hrhrnet_head = hrhrnet_head
self.post_process = HrHRNetPostProcess() self.post_process = post_process
self.flip = eval_flip self.flip = eval_flip
self.flip_perm = paddle.to_tensor(flip_perm) self.flip_perm = paddle.to_tensor(flip_perm)
self.deploy = False self.deploy = False
...@@ -85,6 +85,7 @@ class HigherHRNet(BaseArch): ...@@ -85,6 +85,7 @@ class HigherHRNet(BaseArch):
return self.hrhrnet_head(body_feats, self.inputs) return self.hrhrnet_head(body_feats, self.inputs)
else: else:
outputs = self.hrhrnet_head(body_feats) outputs = self.hrhrnet_head(body_feats)
if self.flip and not self.deploy: if self.flip and not self.deploy:
outputs = [paddle.split(o, 2) for o in outputs] outputs = [paddle.split(o, 2) for o in outputs]
output_rflip = [ output_rflip = [
...@@ -105,7 +106,6 @@ class HigherHRNet(BaseArch): ...@@ -105,7 +106,6 @@ class HigherHRNet(BaseArch):
w = self.inputs['im_shape'][0, 1].numpy().item() w = self.inputs['im_shape'][0, 1].numpy().item()
kpts, scores = self.post_process(*outputs, h, w) kpts, scores = self.post_process(*outputs, h, w)
res_lst.append([kpts, scores]) res_lst.append([kpts, scores])
return res_lst return res_lst
def get_loss(self): def get_loss(self):
...@@ -157,7 +157,7 @@ class HrHRNetPostProcess(object): ...@@ -157,7 +157,7 @@ class HrHRNetPostProcess(object):
original_height, original_width (float): the original image size original_height, original_width (float): the original image size
''' '''
def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.): def __init__(self, max_num_people=30, heat_thresh=0.1, tag_thresh=1.):
self.max_num_people = max_num_people self.max_num_people = max_num_people
self.heat_thresh = heat_thresh self.heat_thresh = heat_thresh
self.tag_thresh = tag_thresh self.tag_thresh = tag_thresh
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册