validate_icnet.py 3.5 KB
Newer Older
L
liuqi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
import argparse
import sys
import tensorflow as tf
import numpy as np

from tensorflow import gfile

# Validation Flow:
# 1. Generate input data (writes both the TF and the MACE input files)
#    python validate_icnet.py --generate_data true \
#        --random_seed 1 \
#        --tf_input_file input_file \
#        --mace_input_file mace_input_file
# 2. Use mace_run to run icnet on the phone with mace_input_file.
# 3. adb pull the MACE output file.
# 4. Compare the output data of MACE and TF
#    python validate_icnet.py --model_file opt_icnet.pb \
#        --tf_input_file input_file \
#        --mace_out_file icnet.out


def generate_data(shape):
  """Generate random input data and write it for both TensorFlow and MACE.

  Seeds NumPy with FLAGS.random_seed and draws uniform values in [0, 1) of
  the given shape. It writes them as raw float32 to FLAGS.tf_input_file.
  It also writes the same values, transposed with axes (2, 0, 1), as raw
  float32 to FLAGS.mace_input_file.

  Args:
    shape: list of ints; must be 3-D because of the (2, 0, 1) transpose.
      With the default --input_shape "480,480,3" this is presumably HWC
      for TF and CHW for MACE (TODO confirm layout).

  Raises:
    ValueError: if either output path flag is empty.
  """
  if not FLAGS.tf_input_file or not FLAGS.mace_input_file:
    raise ValueError(
        "--tf_input_file and --mace_input_file are required with "
        "--generate_data.")

  np.random.seed(FLAGS.random_seed)
  # Cast once; both files then hold exactly the same float32 values.
  data = np.random.random(shape).astype(np.float32)
  print(FLAGS.tf_input_file)
  data.tofile(FLAGS.tf_input_file)
  # tofile always writes in C order, so the transposed view is serialized
  # in its logical (transposed) element order.
  mace_data = np.transpose(data, axes=(2, 0, 1))
  mace_data.tofile(FLAGS.mace_input_file)
  print("Generate input file done.")

def load_data(file):
  """Read a raw binary file of float32 values into a flat 1-D NumPy array."""
  flat_values = np.fromfile(file, dtype=np.float32)
  return flat_values

def valid_output(out_shape, mace_out_file, tf_out_value, rtol=0, atol=1e-5):
  """Compare MACE's output file against the TensorFlow output and report.

  The flat float32 MACE output is reshaped to `out_shape`. The TF output is
  transposed with axes (0, 3, 1, 2) before the element-wise comparison.
  That presumably converts NHWC to NCHW, with MACE emitting NCHW
  (TODO confirm layouts).

  Args:
    out_shape: shape to reshape the MACE output into.
    mace_out_file: path to the raw float32 output produced by MACE.
    tf_out_value: 4-D output array returned by TensorFlow.
    rtol: relative tolerance passed to np.allclose.
    atol: absolute tolerance passed to np.allclose.

  Returns:
    True if every element is within tolerance, otherwise False.
  """
  mace_out_value = load_data(mace_out_file).reshape(out_shape)
  tf_out_data_t = np.transpose(tf_out_value, axes=(0, 3, 1, 2))
  res = np.allclose(mace_out_value, tf_out_data_t, rtol=rtol, atol=atol)
  if res:
    print('Passed! Haha')
  else:
    print('Failed! Oops')
    # Give a hint about how far off the outputs are, not just pass/fail.
    max_diff = np.max(np.abs(mace_out_value - tf_out_data_t))
    print('Max absolute difference: %g' % max_diff)
  return res


def run_model(input_shape):
  """Run the TensorFlow graph on the TF input file and return its output.

  Loads the GraphDef at FLAGS.model_file into a fresh tf.Graph. It reads
  raw float32 data from FLAGS.tf_input_file and reshapes it to
  `input_shape`. It then feeds that data, wrapped in a list (a leading
  batch axis of size 1), to the tensor 'input_node:0'.

  Args:
    input_shape: list of ints the flat input data is reshaped to.

  Returns:
    The value of the tensor 'output_node:0' as returned by Session.run.

  Raises:
    IOError: if FLAGS.model_file does not exist.
  """
  if not gfile.Exists(FLAGS.model_file):
    raise IOError(
        "Input graph file '" + FLAGS.model_file + "' does not exist!")

  graph_def = tf.GraphDef()
  with gfile.Open(FLAGS.model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  # Import the graph exactly once, into a dedicated graph that the session
  # below is bound to.
  graph = tf.Graph()
  with graph.as_default():
    tf.import_graph_def(graph_def, name="")
  input_node = graph.get_tensor_by_name('input_node:0')
  output_node = graph.get_tensor_by_name('output_node:0')

  input_value = load_data(FLAGS.tf_input_file).reshape(input_shape)

  with tf.Session(graph=graph) as session:
    return session.run(output_node, feed_dict={input_node: [input_value]})

def main(unused_args):
  """Generate input data or validate MACE output, depending on FLAGS.

  Args:
    unused_args: argv-style list forwarded from the command line; ignored.
  """
  def to_dims(spec):
    # "480,480,3" -> [480, 480, 3]
    return [int(dim) for dim in spec.split(',')]

  input_shape = to_dims(FLAGS.input_shape)
  output_shape = to_dims(FLAGS.output_shape)

  if not FLAGS.generate_data:
    tf_output = run_model(input_shape)
    valid_output(output_shape, FLAGS.mace_out_file, tf_output)
    return
  generate_data(input_shape)


def parse_args():
  """Parse command line arguments.

  Returns:
    The (namespace, unparsed) tuple from ArgumentParser.parse_known_args:
    the recognized flags, and the list of arguments that were not recognized.
  """
  parser = argparse.ArgumentParser()

  def str2bool(value):
    # Accept common spellings of "true" so that e.g. `--generate_data 1`
    # is not silently parsed as False. Anything else is False.
    return value.lower() in ("true", "1", "yes", "y", "t")

  parser.register("type", "bool", str2bool)
  parser.add_argument(
    "--model_file",
    type=str,
    default="",
    help="TensorFlow \'GraphDef\' file to load.")
  parser.add_argument(
    "--tf_input_file",
    type=str,
    default="",
    help="TensorFlow input data file (written with --generate_data, "
         "otherwise loaded).")
  parser.add_argument(
    "--mace_input_file",
    type=str,
    default="",
    help="MACE input data file to write (used with --generate_data).")
  parser.add_argument(
    "--mace_out_file",
    type=str,
    default="",
    help="mace output file to load.")
  parser.add_argument(
    "--input_shape",
    type=str,
    default="480,480,3",
    help="input shape.")
  parser.add_argument(
    "--output_shape",
    type=str,
    default="1,2,480,480",
    help="output shape.")
  parser.add_argument(
    "--generate_data",
    type='bool',
    default=False,
    help="Generate random input data files instead of validating outputs.")
  parser.add_argument(
    "--random_seed",
    type=int,
    default=0,
    help="Random seed for generate test case.")

  return parser.parse_known_args()


if __name__ == '__main__':
  # FLAGS is bound at module level here and read as a global by the
  # functions above; unrecognized arguments are forwarded to main().
  FLAGS, unparsed = parse_args()
  forwarded_args = [sys.argv[0]]
  forwarded_args.extend(unparsed)
  main(unused_args=forwarded_args)