未验证 提交 cc266e54 编写于 作者: J JYChen 提交者: GitHub

add --save_prediction_only support for TopDown KeyPoint Metric (#3865)

* add --save_prediction_only support for TopDown KeyPoint Metric

* add a use case for save_prediction_only
上级 ed0cd8da
...@@ -75,6 +75,9 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/hig ...@@ -75,6 +75,9 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/hig
#MPII DataSet #MPII DataSet
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/hrnet/hrnet_w32_256x256_mpii.yml
#当只需要保存评估预测的结果时,可以通过设置save_prediction_only参数实现,评估预测结果默认保存在output/keypoints_results.json文件中
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml --save_prediction_only
``` ```
**模型预测:** **模型预测:**
......
...@@ -227,19 +227,27 @@ class Trainer(object): ...@@ -227,19 +227,27 @@ class Trainer(object):
eval_dataset = self.cfg['EvalDataset'] eval_dataset = self.cfg['EvalDataset']
eval_dataset.check_or_download_dataset() eval_dataset.check_or_download_dataset()
anno_file = eval_dataset.get_anno() anno_file = eval_dataset.get_anno()
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [ self._metrics = [
KeyPointTopDownCOCOEval(anno_file, KeyPointTopDownCOCOEval(
len(eval_dataset), self.cfg.num_joints, anno_file,
self.cfg.save_dir) len(eval_dataset),
self.cfg.num_joints,
self.cfg.save_dir,
save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'KeyPointTopDownMPIIEval': elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
eval_dataset = self.cfg['EvalDataset'] eval_dataset = self.cfg['EvalDataset']
eval_dataset.check_or_download_dataset() eval_dataset.check_or_download_dataset()
anno_file = eval_dataset.get_anno() anno_file = eval_dataset.get_anno()
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [ self._metrics = [
KeyPointTopDownMPIIEval(anno_file, KeyPointTopDownMPIIEval(
len(eval_dataset), self.cfg.num_joints, anno_file,
self.cfg.save_dir) len(eval_dataset),
self.cfg.num_joints,
self.cfg.save_dir,
save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'MOTDet': elif self.cfg.metric == 'MOTDet':
self._metrics = [JDEDetMetric(), ] self._metrics = [JDEDetMetric(), ]
......
...@@ -20,6 +20,8 @@ from pycocotools.coco import COCO ...@@ -20,6 +20,8 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from ..modeling.keypoint_utils import oks_nms from ..modeling.keypoint_utils import oks_nms
from scipy.io import loadmat, savemat from scipy.io import loadmat, savemat
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval'] __all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']
...@@ -38,7 +40,8 @@ class KeyPointTopDownCOCOEval(object): ...@@ -38,7 +40,8 @@ class KeyPointTopDownCOCOEval(object):
output_eval, output_eval,
iou_type='keypoints', iou_type='keypoints',
in_vis_thre=0.2, in_vis_thre=0.2,
oks_thre=0.9): oks_thre=0.9,
save_prediction_only=False):
super(KeyPointTopDownCOCOEval, self).__init__() super(KeyPointTopDownCOCOEval, self).__init__()
self.coco = COCO(anno_file) self.coco = COCO(anno_file)
self.num_samples = num_samples self.num_samples = num_samples
...@@ -48,6 +51,7 @@ class KeyPointTopDownCOCOEval(object): ...@@ -48,6 +51,7 @@ class KeyPointTopDownCOCOEval(object):
self.oks_thre = oks_thre self.oks_thre = oks_thre
self.output_eval = output_eval self.output_eval = output_eval
self.res_file = os.path.join(output_eval, "keypoints_results.json") self.res_file = os.path.join(output_eval, "keypoints_results.json")
self.save_prediction_only = save_prediction_only
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -90,6 +94,7 @@ class KeyPointTopDownCOCOEval(object): ...@@ -90,6 +94,7 @@ class KeyPointTopDownCOCOEval(object):
os.makedirs(self.output_eval) os.makedirs(self.output_eval)
with open(self.res_file, 'w') as f: with open(self.res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4) json.dump(results, f, sort_keys=True, indent=4)
logger.info(f'The keypoint result is saved to {self.res_file}.')
try: try:
json.load(open(self.res_file)) json.load(open(self.res_file))
except Exception: except Exception:
...@@ -178,6 +183,10 @@ class KeyPointTopDownCOCOEval(object): ...@@ -178,6 +183,10 @@ class KeyPointTopDownCOCOEval(object):
self.get_final_results(self.results['all_preds'], self.get_final_results(self.results['all_preds'],
self.results['all_boxes'], self.results['all_boxes'],
self.results['image_path']) self.results['image_path'])
if self.save_prediction_only:
logger.info(f'The keypoint result is saved to {self.res_file} '
'and do not evaluate the mAP.')
return
coco_dt = self.coco.loadRes(self.res_file) coco_dt = self.coco.loadRes(self.res_file)
coco_eval = COCOeval(self.coco, coco_dt, 'keypoints') coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
coco_eval.params.useSegm = None coco_eval.params.useSegm = None
...@@ -191,6 +200,8 @@ class KeyPointTopDownCOCOEval(object): ...@@ -191,6 +200,8 @@ class KeyPointTopDownCOCOEval(object):
self.eval_results['keypoint'] = keypoint_stats self.eval_results['keypoint'] = keypoint_stats
def log(self): def log(self):
if self.save_prediction_only:
return
stats_names = [ stats_names = [
'AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
'AR .75', 'AR (M)', 'AR (L)' 'AR .75', 'AR (M)', 'AR (L)'
...@@ -213,9 +224,12 @@ class KeyPointTopDownMPIIEval(object): ...@@ -213,9 +224,12 @@ class KeyPointTopDownMPIIEval(object):
num_samples, num_samples,
num_joints, num_joints,
output_eval, output_eval,
oks_thre=0.9): oks_thre=0.9,
save_prediction_only=False):
super(KeyPointTopDownMPIIEval, self).__init__() super(KeyPointTopDownMPIIEval, self).__init__()
self.ann_file = anno_file self.ann_file = anno_file
self.res_file = os.path.join(output_eval, "keypoints_results.json")
self.save_prediction_only = save_prediction_only
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -239,9 +253,32 @@ class KeyPointTopDownMPIIEval(object): ...@@ -239,9 +253,32 @@ class KeyPointTopDownMPIIEval(object):
self.results.append(results) self.results.append(results)
def accumulate(self): def accumulate(self):
self._mpii_keypoint_results_save()
if self.save_prediction_only:
logger.info(f'The keypoint result is saved to {self.res_file} '
'and do not evaluate the mAP.')
return
self.eval_results = self.evaluate(self.results) self.eval_results = self.evaluate(self.results)
def _mpii_keypoint_results_save(self):
results = []
for res in self.results:
if len(res) == 0:
continue
result = [{
'preds': res['preds'][k].tolist(),
'boxes': res['boxes'][k].tolist(),
'image_path': res['image_path'][k],
} for k in range(len(res))]
results.extend(result)
with open(self.res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)
logger.info(f'The keypoint result is saved to {self.res_file}.')
def log(self): def log(self):
if self.save_prediction_only:
return
for item, value in self.eval_results.items(): for item, value in self.eval_results.items():
print("{} : {}".format(item, value)) print("{} : {}".format(item, value))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册