diff --git a/tools/infer/py_infer.py b/tools/infer/py_infer.py index b566092bc1fa069e6ffd4b051a5de78701dad9ea..8236bc6f96d0c547d0ef685315df1ce5af58c628 100644 --- a/tools/infer/py_infer.py +++ b/tools/infer/py_infer.py @@ -87,6 +87,7 @@ def main(): exe, program, feed_names, fetch_lists = create_predictor(args) data = preprocess(args.image_file, operators) + data = np.expand_dims(data, axis=0) outputs = exe.run(program, feed={feed_names[0]: data}, fetch_list=fetch_lists,