convert_data_pb.py 1.6 KB
Newer Older
M
bugfix  
Macrobull 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import os, sys
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper

from collections import OrderedDict as Dict
from glob import glob


def make_var_name(name):
    """
    make a valid variable name in Python code

    each of the characters ' ', '\\', '|', '/', ':', '.', '-' is replaced by
    '_'; a result starting with '_' gets the prefix 'var', and one starting
    with a digit gets the prefix 'var_'

    raises ValueError if name is empty (or otherwise falsy)
    """

    # explicit check instead of `assert`, which is stripped under `python -O`
    if not name:
        raise ValueError('variable name must be a non-empty string, got {!r}'.format(name))

    # single-character substitutions; equivalent to chained str.replace calls
    # since the replacement '_' is not itself in the set
    separators = ' \\|/:.-'
    name = name.translate(str.maketrans(separators, '_' * len(separators)))

    if name.startswith('_'):
        name = 'var' + name
    elif name[0].isdigit():
        name = 'var_' + name
    return name


def _tensor_index(path):
    """
    sort key for 'input_<k>.pb' / 'output_<k>.pb' files: order by the numeric
    index <k> (so input_10 follows input_9, unlike a plain string sort);
    files without a numeric suffix sort after all numbered ones, by stem
    """

    stem = os.path.splitext(os.path.basename(path))[0]
    index = stem.rpartition('_')[2]
    if index.isdecimal():
        return (0, int(index), stem)
    return (1, 0, stem)


def _load_tensors(directory, pattern, squeeze, target_ndim):
    """
    parse every serialized onnx.TensorProto in directory matching pattern,
    in numeric file-index order, and return them as a list of numpy arrays

    when squeeze is true, leading size-1 axes are dropped while the array
    still has more than target_ndim dimensions
    """

    tensors = []
    # glob() order is arbitrary, but tensors are later paired with names by
    # position, so the order must follow the file index
    for fn in sorted(glob(os.path.join(directory, pattern)), key=_tensor_index):
        proto = onnx.TensorProto()
        with open(fn, 'rb') as f:
            proto.ParseFromString(f.read())
        tensor = numpy_helper.to_array(proto)
        while squeeze and tensor.ndim > target_ndim and tensor.shape[0] == 1:
            tensor = tensor.squeeze(0)
        tensors.append(tensor)
    return tensors


def _name_tensors(names, tensors, kind):
    """
    pair names (sanitized by make_var_name) with tensors by position into an
    OrderedDict; on a count mismatch the extras are dropped, as zip() does,
    but a warning is printed to stderr instead of failing silently
    """

    if len(names) != len(tensors):
        print('warning: {} {} name(s) given but {} {}_*.pb file(s) found; '
              'unmatched entries are dropped'.format(
                  len(names), kind, len(tensors), kind),
              file=sys.stderr)
    return Dict(zip(map(make_var_name, names), tensors))


if len(sys.argv) < 4:
    sys.exit('usage: {} <path inside data dir> <input,names> <output,names> '
             '[squeeze]'.format(sys.argv[0]))

# only the dirname of argv[1] is used: it is the directory holding the .pb
# files and also the base path of the saved archive
data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(',')
output_names = sys.argv[3].split(',')
# any 4th argument (whatever its value) enables squeezing
squeeze_data = len(sys.argv) > 4

inputs = _name_tensors(
    input_names, _load_tensors(data_dir, 'input_*.pb', squeeze_data, 4),
    'input')
outputs = _name_tensors(
    output_names, _load_tensors(data_dir, 'output_*.pb', squeeze_data, 2),
    'output')

# np.savez appends '.npz' to data_dir; the dicts are stored as 0-d object
# arrays, so readers need np.load(..., allow_pickle=True)
np.savez(data_dir, inputs=inputs, outputs=outputs)