# onnx_infer.py -- committed by channingss
import os
import sys
import numpy as np
import onnx
import json
import argparse
from six import text_type as _text_type


def arg_parser():
    """Build the command-line parser used by this script.

    Returns:
        argparse.ArgumentParser: a parser that accepts ``--save_dir`` (short
        form ``-s``), parsed as text and defaulting to ``None``.
    """
    cli = argparse.ArgumentParser()
    save_dir_options = dict(
        type=_text_type,
        default=None,
        help="define save_dir",
    )
    cli.add_argument("--save_dir", "-s", **save_dir_options)
    return cli


def main():
    """Run the saved ONNX model with onnxruntime and store every output.

    The directory given by ``--save_dir`` must contain
    ``onnx_model_infer.onnx`` and ``input_data.npy``. The arrays loaded from
    ``input_data.npy`` are fed to the session inputs in order, and each graph
    output is written into ``save_dir`` with ``np.save`` under the output's
    name.

    Prints a message and returns early when onnxruntime cannot be imported or
    is not version 0.4.0. Exits through ``parser.error`` when ``--save_dir``
    is missing.
    """
    # Only a failed import means "not installed"; anything else should surface.
    try:
        import onnxruntime as rt
    except ImportError:
        print(
            "onnxruntime is not installed, use \"pip install onnxruntime==0.4.0\"."
        )
        return
    if getattr(rt, '__version__', None) != '0.4.0':
        print("onnxruntime==0.4.0 is required")
        return

    parser = arg_parser()
    args = parser.parse_args()

    save_dir = args.save_dir
    if save_dir is None:
        # Without this check os.path.join(None, ...) fails with a TypeError.
        parser.error("--save_dir is required")

    model_path = os.path.join(save_dir, 'onnx_model_infer.onnx')
    data_path = os.path.join(save_dir, 'input_data.npy')

    model = onnx.load(model_path)
    sess = rt.InferenceSession(model_path)

    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects; only
    # point --save_dir at files from a trusted source.
    inputs = np.load(data_path, allow_pickle=True)

    # Bind the i-th saved array to the i-th session input.
    session_inputs = sess.get_inputs()
    inputs_dict = {}
    for i, ipt in enumerate(inputs):
        inputs_dict[session_inputs[i].name] = ipt

    res = sess.run(None, input_feed=inputs_dict)

    # Results are matched to the graph outputs by position.
    for idx, value_info in enumerate(model.graph.output):
        np.save(os.path.join(save_dir, value_info.name), res[idx])

# Script entry point: run the inference only when executed directly.
if __name__ == "__main__":
    main()