validate.py 8.5 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
yejianwu 已提交
15 16 17 18 19
import argparse
import sys
import os
import os.path
import numpy as np
L
liuqi 已提交
20
import re
Y
yejianwu 已提交
21
from scipy import spatial
22
from scipy import stats
Y
yejianwu 已提交
23 24 25

# Validation Flow:
# 1. Generate input data.
# 2. Use mace_run to run the model on the phone.
# 3. adb pull the result.
# 4. Compare the output data of MACE and the reference framework
#    (TensorFlow or Caffe), e.g. for TensorFlow:
#    python validate.py --platform tensorflow \
#        --model_file tf_model_opt.pb \
#        --input_file input_file \
#        --mace_out_file output_file \
#        --input_node input_node \
#        --output_node output_node \
#        --input_shape 1,64,64,3 \
#        --output_shape 1,64,64,2
Y
yejianwu 已提交
36

L
Liangliang He 已提交
37

Y
yejianwu 已提交
38
def load_data(file):
    """Read a raw float32 dump from disk.

    Args:
        file: path of a binary file holding packed float32 values.

    Returns:
        A 1-D float32 numpy array with the file contents, or an empty array
        when ``file`` is not an existing regular file (compare_output treats
        an empty array as a missing output).
    """
    if not os.path.isfile(file):
        return np.empty([0])
    return np.fromfile(file=file, dtype=np.float32)

Y
yejianwu 已提交
44

L
liuqi 已提交
45
def format_name(name):
    """Turn a tensor/node name into a filename-safe token.

    Every maximal run of characters outside ``[0-9a-zA-Z]`` collapses into a
    single underscore, e.g. ``'conv1/BiasAdd:0'`` -> ``'conv1_BiasAdd_0'``.
    """
    non_alnum_run = re.compile('[^0-9a-zA-Z]+')
    return non_alnum_run.sub('_', name)

L
liuqi 已提交
48

49 50
def compare_output(platform, mace_runtime, output_name, mace_out_value,
                   out_value):
    """Check one MACE output against the reference framework's output.

    Both arrays are flattened and compared by cosine similarity. The check
    passes when the similarity is strictly greater than the threshold for
    ``mace_runtime``: 0.999 for "cpu", 0.995 for "gpu", 0.930 for "dsp".
    Any other runtime has no threshold and always fails.

    Args:
        platform: reference framework name; only used in the printed report.
        mace_runtime: MACE device the output was produced on.
        output_name: output node name; only used in the printed report.
        mace_out_value: numpy array read from the MACE output dump. An empty
            array (as returned by load_data) means the dump was missing.
        out_value: numpy array produced by the reference framework.

    Returns None when the check passes. Exits the process with status -1
    when the MACE output is empty, the element counts differ, or the
    similarity is not above the threshold.
    """
    similarity_thresholds = {"cpu": 0.999, "gpu": 0.995, "dsp": 0.930}

    if mace_out_value.size == 0:
        # A missing/empty dump is reported as skipped but still fails the run.
        print('=======================Skip empty node===================')
        sys.exit(-1)

    out_value = out_value.reshape(-1)
    mace_out_value = mace_out_value.reshape(-1)
    # Explicit check rather than ``assert`` so it also runs under ``python -O``.
    if out_value.size != mace_out_value.size:
        print('%s: MACE output has %d elements, %s output has %d' %
              (output_name, mace_out_value.size, platform.upper(),
               out_value.size))
        sys.exit(-1)

    similarity = 1 - spatial.distance.cosine(out_value, mace_out_value)
    # One pre-formatted string so the output is the same on Python 2 and 3.
    print('%s MACE VS %s similarity:  %s' %
          (output_name, platform.upper(), similarity))

    threshold = similarity_thresholds.get(mace_runtime)
    # A NaN similarity compares False and therefore fails.
    if threshold is not None and similarity > threshold:
        print('===================Similarity Test Passed==================')
    else:
        print('===================Similarity Test Failed==================')
        sys.exit(-1)
Y
yejianwu 已提交
68 69


70 71
def validate_tf_model(platform, mace_runtime, model_file, input_file,
                      mace_out_file, input_names, input_shapes, output_names):
    """Run a serialized TensorFlow graph and compare its outputs with MACE's.

    Args:
        platform: reference framework name, forwarded to compare_output.
        mace_runtime: MACE device, forwarded to compare_output.
        model_file: path to the serialized TensorFlow GraphDef.
        input_file: prefix of the input dumps; the dump for input ``name`` is
            read from ``input_file + "_" + format_name(name)``.
        mace_out_file: prefix of the MACE output dumps, named the same way.
        input_names: graph input node names (tensor ``name + ':0'`` is fed).
        input_shapes: one shape (list of ints) per entry of ``input_names``.
        output_names: graph output node names (tensor ``name + ':0'`` is
            fetched and compared).

    Exits the process with status -1 if ``model_file`` does not exist;
    compare_output may also exit on a failed comparison.
    """
    import tensorflow as tf
    if not os.path.isfile(model_file):
        print("Input graph file '" + model_file + "' does not exist!")
        sys.exit(-1)

    tf.reset_default_graph()
    input_graph_def = tf.GraphDef()
    # Only the parse needs the file; close it before running the model.
    with open(model_file, "rb") as f:
        input_graph_def.ParseFromString(f.read())
    # Import exactly once, with no name prefix, so node names in the graph
    # match input_names/output_names.
    tf.import_graph_def(input_graph_def, name="")
    graph = tf.get_default_graph()

    with tf.Session(graph=graph) as session:
        input_dict = {}
        for name, shape in zip(input_names, input_shapes):
            input_value = load_data(input_file + "_" + format_name(name))
            input_tensor = graph.get_tensor_by_name(name + ':0')
            input_dict[input_tensor] = input_value.reshape(shape)

        output_tensors = [graph.get_tensor_by_name(name + ':0')
                          for name in output_names]
        output_values = session.run(output_tensors, feed_dict=input_dict)

    for name, output_value in zip(output_names, output_values):
        mace_out_value = load_data(mace_out_file + "_" + format_name(name))
        compare_output(platform, mace_runtime, name, mace_out_value,
                       output_value)
L
Liangliang He 已提交
107 108


109 110 111
def _caffe_blob_name(net, name):
    """Map a layer name to its first top blob name, else return it unchanged.

    Names that are not keys of ``net.top_names`` are returned as-is, so
    callers may pass either a layer name or a blob name.
    """
    try:
        if name in net.top_names:
            return net.top_names[name][0]
    except ValueError:
        # Best-effort lookup: if the membership test raises ValueError, fall
        # back to using ``name`` directly as the blob name.
        pass
    return name


def validate_caffe_model(platform, mace_runtime, model_file, input_file,
                         mace_out_file, weight_file, input_names, input_shapes,
                         output_names, output_shapes):
    """Run a Caffe model on CPU and compare its outputs with MACE's.

    Args:
        platform: reference framework name, forwarded to compare_output.
        mace_runtime: MACE device, forwarded to compare_output.
        model_file: network definition file passed to ``caffe.Net``.
        input_file: prefix of the input dumps; the dump for input ``name``
            is read from ``input_file + "_" + format_name(name)``.
        mace_out_file: prefix of the MACE output dumps, named the same way.
        weight_file: weights file passed to ``caffe.Net``.
        input_names: input layer or blob names.
        input_shapes: one 4-D shape per input in the dump's layout; the data
            is transposed with (0, 3, 1, 2) before being fed to Caffe.
        output_names: output layer or blob names to compare.
        output_shapes: one 4-D shape per output, in the same layout as
            ``input_shapes``. The caller's lists are not modified.

    Exits the process with status -1 if ``model_file`` or ``weight_file``
    does not exist; compare_output may also exit on a failed comparison.
    """
    os.environ['GLOG_minloglevel'] = '1'  # suppress Caffe verbose prints
    import caffe
    if not os.path.isfile(model_file):
        print("Input graph file '" + model_file + "' does not exist!")
        sys.exit(-1)
    if not os.path.isfile(weight_file):
        print("Input weight file '" + weight_file + "' does not exist!")
        sys.exit(-1)

    caffe.set_mode_cpu()

    net = caffe.Net(model_file, caffe.TEST, weights=weight_file)

    for name, shape in zip(input_names, input_shapes):
        input_value = load_data(input_file + "_" + format_name(name))
        # Moves the last axis to position 1; presumably NHWC dump -> NCHW
        # Caffe blob (TODO confirm against the data generator).
        input_value = input_value.reshape(shape).transpose((0, 3, 1, 2))
        net.blobs[_caffe_blob_name(net, name)].data[0] = input_value

    net.forward()

    for name, shape in zip(output_names, output_shapes):
        value = net.blobs[_caffe_blob_name(net, name)].data
        # Reorder a copy: swapping in place would corrupt the caller's
        # output_shapes list.
        out_shape = list(shape)
        out_shape[1], out_shape[2], out_shape[3] = (out_shape[3],
                                                    out_shape[1],
                                                    out_shape[2])
        # (0, 2, 3, 1) is the inverse of the input-side (0, 3, 1, 2)
        # transpose, bringing the result back to the dump's layout.
        value = value.reshape(out_shape).transpose((0, 2, 3, 1))
        mace_out_value = load_data(mace_out_file + "_" + format_name(name))
        compare_output(platform, mace_runtime, name, mace_out_value, value)
L
Liangliang He 已提交
150

L
liuqi 已提交
151

152 153 154 155
def validate(platform, model_file, weight_file, input_file, mace_out_file,
             mace_runtime, input_shape, output_shape, input_node, output_node):
    """Validate MACE outputs against a TensorFlow or Caffe reference run.

    Args:
        platform: "tensorflow" or "caffe" (case-insensitive).
        model_file: TensorFlow GraphDef file, or the Caffe network definition
            passed to ``caffe.Net``.
        weight_file: Caffe weights file (unused for TensorFlow).
        input_file: prefix of the input data dumps.
        mace_out_file: prefix of the MACE output dumps.
        mace_runtime: MACE device the outputs came from.
        input_shape: ':'-separated list of comma-separated input shapes,
            e.g. "1,64,64,3:1,32,32,3".
        output_shape: same format as ``input_shape``; only used for Caffe.
        input_node: comma-separated input node names.
        output_node: comma-separated output node names.

    Exits the process with status -1 when the name and shape counts
    disagree or the platform is not supported.
    """
    input_names = input_node.split(',')
    input_shapes = [[int(x) for x in shape.split(',')]
                    for shape in input_shape.split(':')]
    output_names = output_node.split(',')
    # Explicit checks instead of ``assert`` so they also run under -O.
    if len(input_names) != len(input_shapes):
        print("Got %d input node(s) but %d input shape(s)." %
              (len(input_names), len(input_shapes)))
        sys.exit(-1)

    platform = platform.lower()
    if platform == 'tensorflow':
        validate_tf_model(platform, mace_runtime, model_file, input_file,
                          mace_out_file, input_names, input_shapes,
                          output_names)
    elif platform == 'caffe':
        output_shapes = [[int(x) for x in shape.split(',')]
                         for shape in output_shape.split(':')]
        if len(output_names) != len(output_shapes):
            print("Got %d output node(s) but %d output shape(s)." %
                  (len(output_names), len(output_shapes)))
            sys.exit(-1)
        validate_caffe_model(platform, mace_runtime, model_file, input_file,
                             mace_out_file, weight_file, input_names,
                             input_shapes, output_names, output_shapes)
    else:
        # Fail loudly: --platform defaults to "", and returning silently
        # would make a misconfigured run look like a passed validation.
        print("Unsupported platform '%s'; expected 'tensorflow' or 'caffe'." %
              platform)
        sys.exit(-1)
L
Liangliang He 已提交
172

Y
yejianwu 已提交
173 174

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The ``(namespace, leftover_args)`` pair produced by
        ``ArgumentParser.parse_known_args``; unrecognised flags end up in
        ``leftover_args`` instead of causing an error.
    """
    parser = argparse.ArgumentParser()
    # Registers a "bool" type converter; no option below currently uses it.
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # (flag, default, help) for every string-valued option, in --help order.
    string_options = (
        ("--platform", "", "Tensorflow or Caffe."),
        ("--model_file", "", "TensorFlow or Caffe 'GraphDef' file to load."),
        ("--weight_file", "", "caffe model file to load."),
        ("--input_file", "", "input file."),
        ("--mace_out_file", "", "mace output file to load."),
        ("--mace_runtime", "gpu", "mace runtime device."),
        ("--input_shape", "1,64,64,3", "input shape."),
        ("--output_shape", "1,64,64,2", "output shape."),
        ("--input_node", "input_node", "input node"),
        ("--output_node", "output_node", "output node"),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(
            flag, type=str, default=default_value, help=help_text)

    return parser.parse_known_args()
Y
yejianwu 已提交
209 210 211


if __name__ == '__main__':
    # Script entry point: parse flags (ignoring unknown ones) and validate.
    FLAGS, _ = parse_args()
    validate(platform=FLAGS.platform,
             model_file=FLAGS.model_file,
             weight_file=FLAGS.weight_file,
             input_file=FLAGS.input_file,
             mace_out_file=FLAGS.mace_out_file,
             mace_runtime=FLAGS.mace_runtime,
             input_shape=FLAGS.input_shape,
             output_shape=FLAGS.output_shape,
             input_node=FLAGS.input_node,
             output_node=FLAGS.output_node)