未验证 提交 53c2b8b6 编写于 作者: W Wenyu 提交者: GitHub

Save infer results to file (#5792)

* save infer results to file.

* update, for mac temp file

* update func name, fix mem, fix dir
上级 7c5736b2
...@@ -82,3 +82,7 @@ ppdet/version.py ...@@ -82,3 +82,7 @@ ppdet/version.py
# NPU meta folder # NPU meta folder
kernel_meta/ kernel_meta/
# MAC
*.DS_Store
...@@ -53,7 +53,7 @@ list below can be viewed by `--help` ...@@ -53,7 +53,7 @@ list below can be viewed by `--help`
| --draw_threshold | infer | Threshold to reserve the result for visualization | 0.5 | such as `--draw_threshold 0.7` | | --draw_threshold | infer | Threshold to reserve the result for visualization | 0.5 | such as `--draw_threshold 0.7` |
| --infer_dir | infer | Directory for images to perform inference on | None | One of `infer_dir` and `infer_img` is requied | | --infer_dir | infer | Directory for images to perform inference on | None | One of `infer_dir` and `infer_img` is requied |
| --infer_img | infer | Image path | None | One of `infer_dir` and `infer_img` is requied, `infer_img` has higher priority over `infer_dir` | | --infer_img | infer | Image path | None | One of `infer_dir` and `infer_img` is requied, `infer_img` has higher priority over `infer_dir` |
| --save_results | infer | Whether to save detection results to file | False | Optional
......
...@@ -215,7 +215,7 @@ visualdl --logdir vdl_dir/scalar/ ...@@ -215,7 +215,7 @@ visualdl --logdir vdl_dir/scalar/
| --draw_threshold | infer | 可视化时分数阈值 | 0.5 | 例如`--draw_threshold=0.7` | | --draw_threshold | infer | 可视化时分数阈值 | 0.5 | 例如`--draw_threshold=0.7` |
| --infer_dir | infer | 用于预测的图片文件夹路径 | None | `--infer_img``--infer_dir`必须至少设置一个 | | --infer_dir | infer | 用于预测的图片文件夹路径 | None | `--infer_img``--infer_dir`必须至少设置一个 |
| --infer_img | infer | 用于预测的图片路径 | None | `--infer_img``--infer_dir`必须至少设置一个,`infer_img`具有更高优先级 | | --infer_img | infer | 用于预测的图片路径 | None | `--infer_img``--infer_dir`必须至少设置一个,`infer_img`具有更高优先级 |
| --save_txt | infer | 是否在文件夹下将图片的预测结果保存到文本文件中 | False | 可选 | | --save_results | infer | 是否在文件夹下将图片的预测结果保存到文件中 | False | 可选 |
## 8 模型导出 ## 8 模型导出
......
...@@ -25,6 +25,7 @@ from tqdm import tqdm ...@@ -25,6 +25,7 @@ from tqdm import tqdm
import numpy as np import numpy as np
import typing import typing
from PIL import Image, ImageOps, ImageFile from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
import paddle import paddle
...@@ -552,10 +553,45 @@ class Trainer(object): ...@@ -552,10 +553,45 @@ class Trainer(object):
images, images,
draw_threshold=0.5, draw_threshold=0.5,
output_dir='output', output_dir='output',
save_txt=False): save_results=False):
self.dataset.set_images(images) self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0) loader = create('TestReader')(self.dataset, 0)
def setup_metrics_for_loader():
# mem
metrics = copy.deepcopy(self._metrics)
mode = self.mode
save_prediction_only = self.cfg[
'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
output_eval = self.cfg[
'output_eval'] if 'output_eval' in self.cfg else None
# modify
self.mode = '_test'
self.cfg['save_prediction_only'] = True
self.cfg['output_eval'] = output_dir
self._init_metrics()
# restore
self.mode = mode
self.cfg.pop('save_prediction_only')
if save_prediction_only is not None:
self.cfg['save_prediction_only'] = save_prediction_only
self.cfg.pop('output_eval')
if output_eval is not None:
self.cfg['output_eval'] = output_eval
_metrics = copy.deepcopy(self._metrics)
self._metrics = metrics
return _metrics
if save_results:
metrics = setup_metrics_for_loader()
else:
metrics = []
imid2path = self.dataset.get_imid2path() imid2path = self.dataset.get_imid2path()
anno_file = self.dataset.get_anno() anno_file = self.dataset.get_anno()
...@@ -574,6 +610,9 @@ class Trainer(object): ...@@ -574,6 +610,9 @@ class Trainer(object):
# forward # forward
outs = self.model(data) outs = self.model(data)
for _m in metrics:
_m.update(data, outs)
for key in ['im_shape', 'scale_factor', 'im_id']: for key in ['im_shape', 'scale_factor', 'im_id']:
if isinstance(data, typing.Sequence): if isinstance(data, typing.Sequence):
outs[key] = data[0][key] outs[key] = data[0][key]
...@@ -583,11 +622,16 @@ class Trainer(object): ...@@ -583,11 +622,16 @@ class Trainer(object):
if hasattr(value, 'numpy'): if hasattr(value, 'numpy'):
outs[key] = value.numpy() outs[key] = value.numpy()
results.append(outs) results.append(outs)
# sniper # sniper
if type(self.dataset) == SniperCOCODataSet: if type(self.dataset) == SniperCOCODataSet:
results = self.dataset.anno_cropper.aggregate_chips_detections( results = self.dataset.anno_cropper.aggregate_chips_detections(
results) results)
for _m in metrics:
_m.accumulate()
_m.reset()
for outs in results: for outs in results:
batch_res = get_infer_results(outs, clsid2catid) batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num'] bbox_num = outs['bbox_num']
...@@ -619,15 +663,7 @@ class Trainer(object): ...@@ -619,15 +663,7 @@ class Trainer(object):
logger.info("Detection bbox results save in {}".format( logger.info("Detection bbox results save in {}".format(
save_name)) save_name))
image.save(save_name, quality=95) image.save(save_name, quality=95)
if save_txt:
save_path = os.path.splitext(save_name)[0] + '.txt'
results = {}
results["im_id"] = im_id
if bbox_res:
results["bbox_res"] = bbox_res
if keypoint_res:
results["keypoint_res"] = keypoint_res
save_result(save_path, results, catid2name, draw_threshold)
start = end start = end
def _get_save_image_name(self, output_dir, image_path): def _get_save_image_name(self, output_dir, image_path):
......
...@@ -22,6 +22,7 @@ import json ...@@ -22,6 +22,7 @@ import json
import paddle import paddle
import numpy as np import numpy as np
import typing import typing
from pathlib import Path
from .map_utils import prune_zero_padding, DetectionMAP from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval from .coco_utils import get_infer_results, cocoapi_eval
...@@ -32,13 +33,8 @@ from ppdet.utils.logger import setup_logger ...@@ -32,13 +33,8 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
__all__ = [ __all__ = [
'Metric', 'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results',
'COCOMetric', 'RBoxMetric', 'SNIPERCOCOMetric'
'VOCMetric',
'WiderFaceMetric',
'get_infer_results',
'RBoxMetric',
'SNIPERCOCOMetric'
] ]
COCO_SIGMAS = np.array([ COCO_SIGMAS = np.array([
...@@ -74,8 +70,6 @@ class Metric(paddle.metric.Metric): ...@@ -74,8 +70,6 @@ class Metric(paddle.metric.Metric):
class COCOMetric(Metric): class COCOMetric(Metric):
def __init__(self, anno_file, **kwargs): def __init__(self, anno_file, **kwargs):
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
self.anno_file = anno_file self.anno_file = anno_file
self.clsid2catid = kwargs.get('clsid2catid', None) self.clsid2catid = kwargs.get('clsid2catid', None)
if self.clsid2catid is None: if self.clsid2catid is None:
...@@ -86,6 +80,14 @@ class COCOMetric(Metric): ...@@ -86,6 +80,14 @@ class COCOMetric(Metric):
self.bias = kwargs.get('bias', 0) self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False) self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.iou_type = kwargs.get('IouType', 'bbox') self.iou_type = kwargs.get('IouType', 'bbox')
if not self.save_prediction_only:
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
if self.output_eval is not None:
Path(self.output_eval).mkdir(exist_ok=True)
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -427,11 +429,13 @@ class SNIPERCOCOMetric(COCOMetric): ...@@ -427,11 +429,13 @@ class SNIPERCOCOMetric(COCOMetric):
self.chip_results.append(outs) self.chip_results.append(outs)
def accumulate(self): def accumulate(self):
results = self.dataset.anno_cropper.aggregate_chips_detections(self.chip_results) results = self.dataset.anno_cropper.aggregate_chips_detections(
self.chip_results)
for outs in results: for outs in results:
infer_results = get_infer_results(outs, self.clsid2catid, bias=self.bias) infer_results = get_infer_results(
self.results['bbox'] += infer_results['bbox'] if 'bbox' in infer_results else [] outs, self.clsid2catid, bias=self.bias)
self.results['bbox'] += infer_results[
'bbox'] if 'bbox' in infer_results else []
super(SNIPERCOCOMetric, self).accumulate() super(SNIPERCOCOMetric, self).accumulate()
...@@ -77,10 +77,10 @@ def parse_args(): ...@@ -77,10 +77,10 @@ def parse_args():
default="vdl_log_dir/image", default="vdl_log_dir/image",
help='VisualDL logging directory for image.') help='VisualDL logging directory for image.')
parser.add_argument( parser.add_argument(
"--save_txt", "--save_results",
type=bool, type=bool,
default=False, default=False,
help="Whether to save inference result in txt.") help="Whether to save inference results to output_dir.")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -131,7 +131,7 @@ def run(FLAGS, cfg): ...@@ -131,7 +131,7 @@ def run(FLAGS, cfg):
images, images,
draw_threshold=FLAGS.draw_threshold, draw_threshold=FLAGS.draw_threshold,
output_dir=FLAGS.output_dir, output_dir=FLAGS.output_dir,
save_txt=FLAGS.save_txt) save_results=FLAGS.save_results)
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册