diff --git a/x2paddle/onnx_infer.py b/x2paddle/onnx_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d37c431e4d2ffe09201b82830cafe04e028175b4 --- /dev/null +++ b/x2paddle/onnx_infer.py @@ -0,0 +1,44 @@ +import onnxruntime as rt +import os +import sys +import numpy as np +import onnx +import json +import argparse +from six import text_type as _text_type + + +def arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--save_dir", + "-s", + type=_text_type, + default=None, + help="define save_dir") + return parser + + +def main(): + parser = arg_parser() + args = parser.parse_args() + + save_dir = args.save_dir + model_dir = os.path.join(save_dir, 'onnx_model_infer.onnx') + data_dir = os.path.join(save_dir, 'input_data.npy') + + model = onnx.load(model_dir) + sess = rt.InferenceSession(model_dir) + + inputs = np.load(data_dir, allow_pickle=True) + data_dir + inputs_dict = {} + for i, ipt in enumerate(inputs): + inputs_dict[sess.get_inputs()[i].name] = ipt + + res = sess.run(None, input_feed=inputs_dict) + for idx, value_info in enumerate(model.graph.output): + np.save(os.path.join(save_dir, value_info.name), res[idx]) + + +if __name__ == "__main__": + main()