generate_tf2_models.py 1.8 KB
Newer Older
D
Dmitry Kurtaev 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
# This script is used to generate test data for OpenCV deep learning module.
import numpy as np
import tensorflow as tf
import shutil

assert(tf.__version__ >= '2.0.0')

from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

np.random.seed(2701)

def gen_data(placeholder):
    shape = placeholder.shape.as_list()
    shape[0] = shape[0] if shape[0] else 1  # batch size = 1 instead None
    return np.random.standard_normal(shape).astype(placeholder.dtype.as_numpy_dtype())


def writeBlob(data, name):
    try:
        data = data.numpy()
    except:
        pass

    if data.ndim == 4:
        # NHWC->NCHW
        data = data.transpose(0, 3, 1, 2)
    elif data.ndim == 5:
        # NDHWC->NCDHW
        data = data.transpose(0, 4, 1, 2, 3)

    data = np.ascontiguousarray(data.astype(np.float32))
    np.save(name + '.npy', data)


def save(model, name, **kwargs):
    model.save(name)

    assert(len(kwargs) == 1)

    inputData = gen_data(next(iter(kwargs.values())))
    outputData = model(inputData)

    writeBlob(inputData, name + '_in')
    writeBlob(outputData, name + '_out')

    # Freeze model
    loaded = tf.saved_model.load(name)
    infer = loaded.signatures['serving_default']

    f = tf.function(infer).get_concrete_function(**kwargs)
    f2 = convert_variables_to_constants_v2(f)
    graph_def = f2.graph.as_graph_def()

    # print(graph_def)

    with tf.io.gfile.GFile(name + '_net.pb', 'wb') as f:
        f.write(graph_def.SerializeToString())

    shutil.rmtree(name)


# Test cases ###################################################################
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(1, 2, 3)),
  tf.keras.layers.Dense(3, activation='relu'),
])
save(model, 'tf2_dense', flatten_input=tf.TensorSpec(shape=[None, 1, 2, 3], dtype=tf.float32))