未验证 提交 53c2b8b6 编写于 作者: W Wenyu 提交者: GitHub

Save infer results to file (#5792)

* save infer results to file.

* update, for mac temp file

* update func name, fix mem, fix dir
上级 7c5736b2
......@@ -82,3 +82,7 @@ ppdet/version.py
# NPU meta folder
kernel_meta/
# MAC
*.DS_Store
......@@ -53,7 +53,7 @@ list below can be viewed by `--help`
| --draw_threshold | infer | Threshold to reserve the result for visualization | 0.5 | such as `--draw_threshold 0.7` |
| --infer_dir | infer | Directory for images to perform inference on | None | One of `infer_dir` and `infer_img` is requied |
| --infer_img | infer | Image path | None | One of `infer_dir` and `infer_img` is requied, `infer_img` has higher priority over `infer_dir` |
| --save_results | infer | Whether to save detection results to file | False | Optional
......
......@@ -215,7 +215,7 @@ visualdl --logdir vdl_dir/scalar/
| --draw_threshold | infer | 可视化时分数阈值 | 0.5 | 例如`--draw_threshold=0.7` |
| --infer_dir | infer | 用于预测的图片文件夹路径 | None | `--infer_img``--infer_dir`必须至少设置一个 |
| --infer_img | infer | 用于预测的图片路径 | None | `--infer_img``--infer_dir`必须至少设置一个,`infer_img`具有更高优先级 |
| --save_txt | infer | 是否在文件夹下将图片的预测结果保存到文本文件中 | False | 可选 |
| --save_results | infer | 是否在文件夹下将图片的预测结果保存到文件中 | False | 可选 |
## 8 模型导出
......
......@@ -25,6 +25,7 @@ from tqdm import tqdm
import numpy as np
import typing
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import paddle
......@@ -552,10 +553,45 @@ class Trainer(object):
images,
draw_threshold=0.5,
output_dir='output',
save_txt=False):
save_results=False):
self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0)
def setup_metrics_for_loader():
# mem
metrics = copy.deepcopy(self._metrics)
mode = self.mode
save_prediction_only = self.cfg[
'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
output_eval = self.cfg[
'output_eval'] if 'output_eval' in self.cfg else None
# modify
self.mode = '_test'
self.cfg['save_prediction_only'] = True
self.cfg['output_eval'] = output_dir
self._init_metrics()
# restore
self.mode = mode
self.cfg.pop('save_prediction_only')
if save_prediction_only is not None:
self.cfg['save_prediction_only'] = save_prediction_only
self.cfg.pop('output_eval')
if output_eval is not None:
self.cfg['output_eval'] = output_eval
_metrics = copy.deepcopy(self._metrics)
self._metrics = metrics
return _metrics
if save_results:
metrics = setup_metrics_for_loader()
else:
metrics = []
imid2path = self.dataset.get_imid2path()
anno_file = self.dataset.get_anno()
......@@ -574,6 +610,9 @@ class Trainer(object):
# forward
outs = self.model(data)
for _m in metrics:
_m.update(data, outs)
for key in ['im_shape', 'scale_factor', 'im_id']:
if isinstance(data, typing.Sequence):
outs[key] = data[0][key]
......@@ -583,11 +622,16 @@ class Trainer(object):
if hasattr(value, 'numpy'):
outs[key] = value.numpy()
results.append(outs)
# sniper
if type(self.dataset) == SniperCOCODataSet:
results = self.dataset.anno_cropper.aggregate_chips_detections(
results)
for _m in metrics:
_m.accumulate()
_m.reset()
for outs in results:
batch_res = get_infer_results(outs, clsid2catid)
bbox_num = outs['bbox_num']
......@@ -619,15 +663,7 @@ class Trainer(object):
logger.info("Detection bbox results save in {}".format(
save_name))
image.save(save_name, quality=95)
if save_txt:
save_path = os.path.splitext(save_name)[0] + '.txt'
results = {}
results["im_id"] = im_id
if bbox_res:
results["bbox_res"] = bbox_res
if keypoint_res:
results["keypoint_res"] = keypoint_res
save_result(save_path, results, catid2name, draw_threshold)
start = end
def _get_save_image_name(self, output_dir, image_path):
......
......@@ -22,6 +22,7 @@ import json
import paddle
import numpy as np
import typing
from pathlib import Path
from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
......@@ -32,13 +33,8 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'Metric',
'COCOMetric',
'VOCMetric',
'WiderFaceMetric',
'get_infer_results',
'RBoxMetric',
'SNIPERCOCOMetric'
'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results',
'RBoxMetric', 'SNIPERCOCOMetric'
]
COCO_SIGMAS = np.array([
......@@ -74,8 +70,6 @@ class Metric(paddle.metric.Metric):
class COCOMetric(Metric):
def __init__(self, anno_file, **kwargs):
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
self.anno_file = anno_file
self.clsid2catid = kwargs.get('clsid2catid', None)
if self.clsid2catid is None:
......@@ -86,6 +80,14 @@ class COCOMetric(Metric):
self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.iou_type = kwargs.get('IouType', 'bbox')
if not self.save_prediction_only:
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
if self.output_eval is not None:
Path(self.output_eval).mkdir(exist_ok=True)
self.reset()
def reset(self):
......@@ -427,11 +429,13 @@ class SNIPERCOCOMetric(COCOMetric):
self.chip_results.append(outs)
def accumulate(self):
results = self.dataset.anno_cropper.aggregate_chips_detections(self.chip_results)
results = self.dataset.anno_cropper.aggregate_chips_detections(
self.chip_results)
for outs in results:
infer_results = get_infer_results(outs, self.clsid2catid, bias=self.bias)
self.results['bbox'] += infer_results['bbox'] if 'bbox' in infer_results else []
infer_results = get_infer_results(
outs, self.clsid2catid, bias=self.bias)
self.results['bbox'] += infer_results[
'bbox'] if 'bbox' in infer_results else []
super(SNIPERCOCOMetric, self).accumulate()
......@@ -77,10 +77,10 @@ def parse_args():
default="vdl_log_dir/image",
help='VisualDL logging directory for image.')
parser.add_argument(
"--save_txt",
"--save_results",
type=bool,
default=False,
help="Whether to save inference result in txt.")
help="Whether to save inference results to output_dir.")
args = parser.parse_args()
return args
......@@ -131,7 +131,7 @@ def run(FLAGS, cfg):
images,
draw_threshold=FLAGS.draw_threshold,
output_dir=FLAGS.output_dir,
save_txt=FLAGS.save_txt)
save_results=FLAGS.save_results)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册