未验证 提交 258bc64a 编写于 作者: K Kaipeng Deng 提交者: GitHub

add save_inference_model in infer.py (#2714)

* add save_inference_model in infer.py

* format code

* add comment

* add save_inference_model doc

* refine doc
上级 7225e149
......@@ -72,6 +72,18 @@ python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_dir=demo
The visualization files are saved in `output` by default, to specify a different
path, simply add a `--save_file=` flag.
- Save inference model
```bash
export CUDA_VISIBLE_DEVICES=0
# or run on CPU with:
# export CPU_NUM=1
python tools/infer.py -c configs/faster_rcnn_r50_1x.yml --infer_img=demo/000000570688.jpg \
--save_inference_model
```
Save inference model by set `--save_inference_model`.
## FAQ
......
......@@ -81,6 +81,24 @@ def get_test_images(infer_dir, infer_img):
return images
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feeded_var_names = [var.name for var in feed_vars.values()]
# im_id is only used for visualize, not used in inference model
feeded_var_names.remove('im_id')
target_vars = test_fetches.values()
logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars]))
fluid.io.save_inference_model(save_dir,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=exe,
main_program=infer_prog,
params_filename="__parmas__")
def main():
cfg = load_config(FLAGS.config)
......@@ -119,6 +137,9 @@ def main():
if cfg.weights:
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
if FLAGS.save_inference_model:
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
# parse infer fetches
extra_keys = []
if cfg['metric'] == 'COCO':
......@@ -196,5 +217,10 @@ if __name__ == '__main__':
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
parser.add_argument(
"--save_inference_model",
action='store_true',
default=False,
help="Save inference model in output_dir if True.")
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册