未验证 提交 effbad5d 编写于 作者: S shangliang Xu 提交者: GitHub

add save only in eval (#2608)

上级 78fdd43e
...@@ -116,9 +116,13 @@ class Trainer(object): ...@@ -116,9 +116,13 @@ class Trainer(object):
if self.cfg.metric == 'COCO': if self.cfg.metric == 'COCO':
# TODO: bias should be unified # TODO: bias should be unified
bias = self.cfg['bias'] if 'bias' in self.cfg else 0 bias = self.cfg['bias'] if 'bias' in self.cfg else 0
save_prediction_only = self.cfg['save_prediction_only'] \
if 'save_prediction_only' in self.cfg else False
self._metrics = [ self._metrics = [
COCOMetric( COCOMetric(
anno_file=self.dataset.get_anno(), bias=bias) anno_file=self.dataset.get_anno(),
bias=bias,
save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'VOC': elif self.cfg.metric == 'VOC':
self._metrics = [ self._metrics = [
......
...@@ -65,6 +65,7 @@ class COCOMetric(Metric): ...@@ -65,6 +65,7 @@ class COCOMetric(Metric):
self.clsid2catid, self.catid2name = get_categories('COCO', anno_file) self.clsid2catid, self.catid2name = get_categories('COCO', anno_file)
# TODO: bias should be unified # TODO: bias should be unified
self.bias = kwargs.get('bias', 0) self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -97,30 +98,42 @@ class COCOMetric(Metric): ...@@ -97,30 +98,42 @@ class COCOMetric(Metric):
json.dump(self.results['bbox'], f) json.dump(self.results['bbox'], f)
logger.info('The bbox result is saved to bbox.json.') logger.info('The bbox result is saved to bbox.json.')
bbox_stats = cocoapi_eval( if self.save_prediction_only:
'bbox.json', 'bbox', anno_file=self.anno_file) logger.info('The bbox result is saved to bbox.json and do not '
self.eval_results['bbox'] = bbox_stats 'evaluate the mAP.')
sys.stdout.flush() else:
bbox_stats = cocoapi_eval(
'bbox.json', 'bbox', anno_file=self.anno_file)
self.eval_results['bbox'] = bbox_stats
sys.stdout.flush()
if len(self.results['mask']) > 0: if len(self.results['mask']) > 0:
with open("mask.json", 'w') as f: with open("mask.json", 'w') as f:
json.dump(self.results['mask'], f) json.dump(self.results['mask'], f)
logger.info('The mask result is saved to mask.json.') logger.info('The mask result is saved to mask.json.')
seg_stats = cocoapi_eval( if self.save_prediction_only:
'mask.json', 'segm', anno_file=self.anno_file) logger.info('The mask result is saved to mask.json and do not '
self.eval_results['mask'] = seg_stats 'evaluate the mAP.')
sys.stdout.flush() else:
seg_stats = cocoapi_eval(
'mask.json', 'segm', anno_file=self.anno_file)
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
if len(self.results['segm']) > 0: if len(self.results['segm']) > 0:
with open("segm.json", 'w') as f: with open("segm.json", 'w') as f:
json.dump(self.results['segm'], f) json.dump(self.results['segm'], f)
logger.info('The segm result is saved to segm.json.') logger.info('The segm result is saved to segm.json.')
seg_stats = cocoapi_eval( if self.save_prediction_only:
'segm.json', 'segm', anno_file=self.anno_file) logger.info('The segm result is saved to segm.json and do not '
self.eval_results['mask'] = seg_stats 'evaluate the mAP.')
sys.stdout.flush() else:
seg_stats = cocoapi_eval(
'segm.json', 'segm', anno_file=self.anno_file)
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
def log(self): def log(self):
pass pass
......
...@@ -64,6 +64,12 @@ def parse_args(): ...@@ -64,6 +64,12 @@ def parse_args():
action="store_true", action="store_true",
help="whether add bias or not while getting w and h") help="whether add bias or not while getting w and h")
parser.add_argument(
'--save_prediction_only',
action='store_true',
default=False,
help='Whether to save the evaluation results only')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -88,6 +94,7 @@ def main(): ...@@ -88,6 +94,7 @@ def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
# TODO: bias should be unified # TODO: bias should be unified
cfg['bias'] = 1 if FLAGS.bias else 0 cfg['bias'] = 1 if FLAGS.bias else 0
cfg['save_prediction_only'] = FLAGS.save_prediction_only
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
if FLAGS.slim_config: if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config) slim_cfg = load_config(FLAGS.slim_config)
......
...@@ -65,6 +65,11 @@ def parse_args(): ...@@ -65,6 +65,11 @@ def parse_args():
default=False, default=False,
help="If set True, enable continuous evaluation job." help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.") "This flag is only used for internal test.")
parser.add_argument(
'--save_prediction_only',
action='store_true',
default=False,
help='Whether to save the evaluation results only')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -91,6 +96,7 @@ def main(): ...@@ -91,6 +96,7 @@ def main():
FLAGS = parse_args() FLAGS = parse_args()
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
cfg['save_prediction_only'] = FLAGS.save_prediction_only
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
if FLAGS.slim_config: if FLAGS.slim_config:
slim_cfg = load_config(FLAGS.slim_config) slim_cfg = load_config(FLAGS.slim_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册