convert_data_npz_0.py 1.1 KB
Newer Older
M
Macrobull 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull
"""

import sys
import numpy as np

from collections import OrderedDict as Dict

M
Macrobull 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30

def _make_var_name(name):
    """
    make a valid variable name in Python code
    """

    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' *?\\/-:':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name


M
Macrobull 已提交
31 32 33
def main(argv=None):
    """
    Rewrite an .npz data file in place so that its 'inputs' and 'outputs'
    entries become name-keyed OrderedDicts.

    Usage: convert_data_npz_0.py FILE INPUT_NAMES OUTPUT_NAMES [SQUEEZE]

    INPUT_NAMES / OUTPUT_NAMES are ':'-separated lists. The names are passed
    through _make_var_name. Passing any 4th argument enables squeezing of
    leading size-1 axes.

    Note: np.savez stores each dict as an object array (pickled). Reading the
    result back therefore needs np.load(..., allow_pickle=True).
    """

    argv = sys.argv if argv is None else argv
    if len(argv) < 4:
        prog = argv[0] if argv else 'convert_data_npz_0.py'
        sys.exit('usage: {} FILE INPUT_NAMES OUTPUT_NAMES [SQUEEZE]'.format(prog))

    fn = argv[1]
    input_names = argv[2].split(':')
    output_names = argv[3].split(':')
    squeeze_data = len(argv) > 4

    # NpzFile keeps the archive open; read both arrays into memory and close
    # it before the same path is overwritten below
    with np.load(fn, encoding='bytes') as data:
        input_data = data['inputs']
        output_data = data['outputs']

    if squeeze_data:
        # drop leading size-1 axes while inputs keep more than 4 dims and
        # outputs keep more than 2 dims
        while input_data.ndim > 4 and input_data.shape[0] == 1:
            input_data = input_data.squeeze(0)
        while output_data.ndim > 2 and output_data.shape[0] == 1:
            output_data = output_data.squeeze(0)

    # only a single array is stored per side, so zip pairs it with the first
    # name only; any further names are dropped
    inputs = Dict(zip(map(_make_var_name, input_names), [input_data]))
    outputs = Dict(zip(map(_make_var_name, output_names), [output_data]))

    # overwrite: saving through a file object stops np.savez from appending
    # '.npz' to paths that lack that extension (which would create a new file)
    with open(fn, 'wb') as f:
        np.savez(f, inputs=inputs, outputs=outputs)


if __name__ == '__main__':
    main()