未验证 提交 62d82636 编写于 作者: S shangliang Xu 提交者: GitHub

add save only in eval/train test=develop (#2604)

上级 bd7adb05
...@@ -125,6 +125,8 @@ class Trainer(object): ...@@ -125,6 +125,8 @@ class Trainer(object):
bias = self.cfg['bias'] if 'bias' in self.cfg else 0 bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \ output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg['save_prediction_only'] \
if 'save_prediction_only' in self.cfg else False
# pass clsid2catid info to metric instance to avoid multiple loading # pass clsid2catid info to metric instance to avoid multiple loading
# annotation file # annotation file
...@@ -145,7 +147,8 @@ class Trainer(object): ...@@ -145,7 +147,8 @@ class Trainer(object):
clsid2catid=clsid2catid, clsid2catid=clsid2catid,
classwise=classwise, classwise=classwise,
output_eval=output_eval, output_eval=output_eval,
bias=bias) bias=bias,
save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'VOC': elif self.cfg.metric == 'VOC':
self._metrics = [ self._metrics = [
......
...@@ -69,6 +69,7 @@ class COCOMetric(Metric): ...@@ -69,6 +69,7 @@ class COCOMetric(Metric):
self.output_eval = kwargs.get('output_eval', None) self.output_eval = kwargs.get('output_eval', None)
# TODO: bias should be unified # TODO: bias should be unified
self.bias = kwargs.get('bias', 0) self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -104,13 +105,17 @@ class COCOMetric(Metric): ...@@ -104,13 +105,17 @@ class COCOMetric(Metric):
json.dump(self.results['bbox'], f) json.dump(self.results['bbox'], f)
logger.info('The bbox result is saved to bbox.json.') logger.info('The bbox result is saved to bbox.json.')
bbox_stats = cocoapi_eval( if self.save_prediction_only:
output, logger.info('The bbox result is saved to {} and do not '
'bbox', 'evaluate the mAP.'.format(output))
anno_file=self.anno_file, else:
classwise=self.classwise) bbox_stats = cocoapi_eval(
self.eval_results['bbox'] = bbox_stats output,
sys.stdout.flush() 'bbox',
anno_file=self.anno_file,
classwise=self.classwise)
self.eval_results['bbox'] = bbox_stats
sys.stdout.flush()
if len(self.results['mask']) > 0: if len(self.results['mask']) > 0:
output = "mask.json" output = "mask.json"
...@@ -120,13 +125,17 @@ class COCOMetric(Metric): ...@@ -120,13 +125,17 @@ class COCOMetric(Metric):
json.dump(self.results['mask'], f) json.dump(self.results['mask'], f)
logger.info('The mask result is saved to mask.json.') logger.info('The mask result is saved to mask.json.')
seg_stats = cocoapi_eval( if self.save_prediction_only:
output, logger.info('The mask result is saved to {} and do not '
'segm', 'evaluate the mAP.'.format(output))
anno_file=self.anno_file, else:
classwise=self.classwise) seg_stats = cocoapi_eval(
self.eval_results['mask'] = seg_stats output,
sys.stdout.flush() 'segm',
anno_file=self.anno_file,
classwise=self.classwise)
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
if len(self.results['segm']) > 0: if len(self.results['segm']) > 0:
output = "segm.json" output = "segm.json"
...@@ -136,13 +145,17 @@ class COCOMetric(Metric): ...@@ -136,13 +145,17 @@ class COCOMetric(Metric):
json.dump(self.results['segm'], f) json.dump(self.results['segm'], f)
logger.info('The segm result is saved to segm.json.') logger.info('The segm result is saved to segm.json.')
seg_stats = cocoapi_eval( if self.save_prediction_only:
output, logger.info('The segm result is saved to {} and do not '
'segm', 'evaluate the mAP.'.format(output))
anno_file=self.anno_file, else:
classwise=self.classwise) seg_stats = cocoapi_eval(
self.eval_results['mask'] = seg_stats output,
sys.stdout.flush() 'segm',
anno_file=self.anno_file,
classwise=self.classwise)
self.eval_results['mask'] = seg_stats
sys.stdout.flush()
def log(self): def log(self):
pass pass
......
...@@ -66,6 +66,12 @@ def parse_args(): ...@@ -66,6 +66,12 @@ def parse_args():
action="store_true", action="store_true",
help="whether per-category AP and draw P-R Curve or not.") help="whether per-category AP and draw P-R Curve or not.")
parser.add_argument(
'--save_prediction_only',
action='store_true',
default=False,
help='Whether to save the evaluation results only')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -85,7 +91,7 @@ def run(FLAGS, cfg): ...@@ -85,7 +91,7 @@ def run(FLAGS, cfg):
# init parallel environment if nranks > 1 # init parallel environment if nranks > 1
init_parallel_env() init_parallel_env()
# build trainer # build trainer
trainer = Trainer(cfg, mode='eval') trainer = Trainer(cfg, mode='eval')
# load weights # load weights
...@@ -102,6 +108,7 @@ def main(): ...@@ -102,6 +108,7 @@ def main():
cfg['bias'] = 1 if FLAGS.bias else 0 cfg['bias'] = 1 if FLAGS.bias else 0
cfg['classwise'] = True if FLAGS.classwise else False cfg['classwise'] = True if FLAGS.classwise else False
cfg['output_eval'] = FLAGS.output_eval cfg['output_eval'] = FLAGS.output_eval
cfg['save_prediction_only'] = FLAGS.save_prediction_only
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu') place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
......
...@@ -75,6 +75,11 @@ def parse_args(): ...@@ -75,6 +75,11 @@ def parse_args():
type=str, type=str,
default="vdl_log_dir/scalar", default="vdl_log_dir/scalar",
help='VisualDL logging directory for scalar.') help='VisualDL logging directory for scalar.')
parser.add_argument(
'--save_prediction_only',
action='store_true',
default=False,
help='Whether to save the evaluation results only')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -110,6 +115,7 @@ def main(): ...@@ -110,6 +115,7 @@ def main():
cfg['fleet'] = FLAGS.fleet cfg['fleet'] = FLAGS.fleet
cfg['use_vdl'] = FLAGS.use_vdl cfg['use_vdl'] = FLAGS.use_vdl
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
cfg['save_prediction_only'] = FLAGS.save_prediction_only
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu') place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册