validate.py 9.0 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
yejianwu 已提交
15 16 17 18 19
import argparse
import sys
import os
import os.path
import numpy as np
L
liuqi 已提交
20
import re
Y
yejianwu 已提交
21
from scipy import spatial
22
from scipy import stats
Y
yejianwu 已提交
23

L
liuqi 已提交
24 25
import common

Y
yejianwu 已提交
26 27
# Validation Flow:
# 1. Generate input data.
# 2. Use mace_run to run the model on the phone.
# 3. adb pull the result.
# 4. Compare the MACE output with the output of the reference framework
#    (TensorFlow or Caffe):
#    python validate.py --platform tensorflow \
#        --model_file tf_model_opt.pb \
#        --input_file input_file \
#        --mace_out_file output_file \
#        --device_type CPU \
#        --input_node input_node \
#        --output_node output_node \
#        --input_shape 1,64,64,3 \
#        --output_shape 1,64,64,2
Y
yejianwu 已提交
38

39 40
# Module tag passed as the first argument to common.MaceLogger.error when
# this script reports a missing model or weight file.
VALIDATION_MODULE = 'VALIDATION'

L
Liangliang He 已提交
41

Y
yejianwu 已提交
42
def load_data(file):
    """Read a raw float32 dump from disk.

    Args:
        file: path of the binary file to read.

    Returns:
        A 1-D numpy float32 array with the file contents, or the result of
        ``np.empty([0])`` when ``file`` is not an existing regular file.
    """
    if not os.path.isfile(file):
        # A missing file is signalled with an empty array instead of an
        # exception; compare_output() checks ``.size`` to detect it.
        return np.empty([0])
    return np.fromfile(file=file, dtype=np.float32)

Y
yejianwu 已提交
48

49
def compare_output(platform, device_type, output_name, mace_out_value,
                   out_value):
    """Log the cosine similarity between a MACE output and a reference.

    Args:
        platform: name of the reference framework; upper-cased in the log.
        device_type: "CPU", "GPU" or "HEXAGON"; selects the pass threshold.
            Any other value is reported as a failed test.
        output_name: name of the compared output, used in the log line.
        mace_out_value: numpy array holding the MACE result.
        out_value: numpy array holding the reference result.

    Both arrays are flattened before comparison and must then have the same
    length. An empty ``mace_out_value`` is reported through
    ``common.MaceLogger.error``.
    """
    if mace_out_value.size == 0:
        common.MaceLogger.error(
            "", common.StringFormatter.block(
                "Similarity Test failed because of empty output"))
        return

    reference = out_value.reshape(-1)
    actual = mace_out_value.reshape(-1)
    assert len(reference) == len(actual)

    # scipy returns the cosine *distance*; similarity is its complement.
    similarity = (1 - spatial.distance.cosine(reference, actual))
    headline = output_name + ' MACE VS ' + platform.upper()
    common.MaceLogger.summary(headline + ' similarity: ' + str(similarity))

    # Minimum similarity that must be exceeded for each runtime device.
    thresholds = (("CPU", 0.999), ("GPU", 0.995), ("HEXAGON", 0.930))
    passed = any(device_type == device and similarity > limit
                 for device, limit in thresholds)
    if not passed:
        common.MaceLogger.error(
            "", common.StringFormatter.block("Similarity Test Failed"))
    else:
        common.MaceLogger.summary(
            common.StringFormatter.block("Similarity Test Passed"))
Y
yejianwu 已提交
71 72


李寅 已提交
73 74 75 76 77 78 79
def normalize_tf_tensor_name(name):
    """Return ``name`` with an output index suffix.

    A name that contains no ':' gets ':0' appended; a name that already
    contains ':' is returned unchanged.
    """
    return name if ':' in name else name + ':0'


80
def validate_tf_model(platform, device_type, model_file, input_file,
                      mace_out_file, input_names, input_shapes, output_names):
    """Run a serialized TensorFlow GraphDef and compare it with MACE output.

    Args:
        platform: reference framework name, forwarded to compare_output().
        device_type: MACE runtime device, forwarded to compare_output().
        model_file: path of the serialized GraphDef to load.
        input_file: base path of the input data; the per-input file name is
            obtained from common.formatted_file_name(input_file, name).
        mace_out_file: base path of the MACE outputs; the per-output file
            name is obtained from common.formatted_file_name().
        input_names: names of the graph inputs (":0" is appended to names
            that carry no output index).
        input_shapes: one shape per entry of ``input_names``; each loaded
            input is reshaped to it before being fed.
        output_names: names of the graph outputs to fetch and compare.
    """
    import tensorflow as tf
    if not os.path.isfile(model_file):
        # NOTE(review): execution continues past this check unless
        # MaceLogger.error aborts the process -- confirm in common.py.
        common.MaceLogger.error(
            VALIDATION_MODULE,
            "Input graph file '" + model_file + "' does not exist!")

    tf.reset_default_graph()
    input_graph_def = tf.GraphDef()
    # The file handle is only needed for parsing; close it before the
    # session runs instead of keeping it open for the whole validation.
    with open(model_file, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    with tf.Session() as session:
        with session.graph.as_default() as graph:
            # Import the graph exactly once. A second import into the same
            # graph only adds a redundant copy of every node.
            tf.import_graph_def(input_graph_def, name="")

            input_dict = {}
            for i, input_name in enumerate(input_names):
                input_value = load_data(
                    common.formatted_file_name(input_file, input_name))
                input_value = input_value.reshape(input_shapes[i])
                input_node = graph.get_tensor_by_name(
                    normalize_tf_tensor_name(input_name))
                input_dict[input_node] = input_value

            output_nodes = [
                graph.get_tensor_by_name(normalize_tf_tensor_name(name))
                for name in output_names]
            output_values = session.run(output_nodes, feed_dict=input_dict)

            for output_name, output_value in zip(output_names,
                                                 output_values):
                output_file_name = common.formatted_file_name(
                    mace_out_file, output_name)
                mace_out_value = load_data(output_file_name)
                compare_output(platform, device_type, output_name,
                               mace_out_value, output_value)
L
Liangliang He 已提交
119 120


121
def validate_caffe_model(platform, device_type, model_file, input_file,
                         mace_out_file, weight_file, input_names, input_shapes,
                         output_names, output_shapes):
    """Run a Caffe model on CPU and compare its outputs with MACE output.

    Args:
        platform: reference framework name, forwarded to compare_output().
        device_type: MACE runtime device, forwarded to compare_output().
        model_file: path of the Caffe network definition.
        input_file: base path of the input data; the per-input file name is
            obtained from common.formatted_file_name(input_file, name).
        mace_out_file: base path of the MACE outputs; the per-output file
            name is obtained from common.formatted_file_name().
        weight_file: path of the Caffe weights.
        input_names: names of the network inputs.
        input_shapes: one shape per input; each loaded input is reshaped to
            it and then transposed with axes (0, 3, 1, 2).
        output_names: names of the network outputs to compare.
        output_shapes: one shape per output. This argument is not modified.
    """
    os.environ['GLOG_minloglevel'] = '1'  # suppress Caffe verbose prints
    import caffe
    if not os.path.isfile(model_file):
        common.MaceLogger.error(
            VALIDATION_MODULE,
            "Input graph file '" + model_file + "' does not exist!")
    if not os.path.isfile(weight_file):
        common.MaceLogger.error(
            VALIDATION_MODULE,
            "Input weight file '" + weight_file + "' does not exist!")

    caffe.set_mode_cpu()

    net = caffe.Net(model_file, caffe.TEST, weights=weight_file)

    for i, input_name in enumerate(input_names):
        input_value = load_data(
            common.formatted_file_name(input_file, input_name))
        # Reorder axes (0, 1, 2, 3) -> (0, 3, 1, 2) before feeding the blob.
        input_value = input_value.reshape(input_shapes[i]).transpose(
            (0, 3, 1, 2))
        input_blob_name = input_name
        try:
            # Prefer the first top blob of a layer with this name, if any.
            if input_name in net.top_names:
                input_blob_name = net.top_names[input_name][0]
        except ValueError:
            pass
        net.blobs[input_blob_name].data[0] = input_value

    net.forward()

    for i, output_name in enumerate(output_names):
        value = net.blobs[net.top_names[output_name][0]].data
        out_shape = output_shapes[i]
        if len(out_shape) == 4:
            # Build the permuted shape in a new list so the caller's
            # ``output_shapes`` entry is left untouched.
            permuted_shape = [out_shape[0], out_shape[3],
                              out_shape[1], out_shape[2]]
            value = value.reshape(permuted_shape).transpose((0, 2, 3, 1))
        output_file_name = common.formatted_file_name(
            mace_out_file, output_name)
        mace_out_value = load_data(output_file_name)
        compare_output(platform, device_type, output_name, mace_out_value,
                       value)
L
Liangliang He 已提交
166

L
liuqi 已提交
167

168
def validate(platform, model_file, weight_file, input_file, mace_out_file,
             device_type, input_shape, output_shape, input_node, output_node):
    """Parse the string arguments and dispatch to the platform validator.

    Args:
        platform: 'tensorflow' or 'caffe'; any other value performs no
            validation.
        model_file: path of the model file.
        weight_file: path of the weight file (used for 'caffe' only).
        input_file: base path of the input data files.
        mace_out_file: base path of the MACE output files.
        device_type: MACE runtime device.
        input_shape: shapes separated by ':', dimensions separated by ','.
        output_shape: same format as ``input_shape``; parsed for 'caffe'
            only.
        input_node: comma-separated input names.
        output_node: comma-separated output names.

    The number of input names must equal the number of input shapes.
    """
    def parse_shapes(spec):
        # "1,2,3:4,5" -> [[1, 2, 3], [4, 5]]
        return [[int(dim) for dim in one_shape.split(',')]
                for one_shape in spec.split(':')]

    input_names = input_node.split(',')
    input_shapes = parse_shapes(input_shape)
    output_names = output_node.split(',')
    assert len(input_names) == len(input_shapes)

    if platform == 'tensorflow':
        validate_tf_model(platform, device_type, model_file, input_file,
                          mace_out_file, input_names, input_shapes,
                          output_names)
        return

    if platform == 'caffe':
        output_shapes = parse_shapes(output_shape)
        validate_caffe_model(platform, device_type, model_file, input_file,
                             mace_out_file, weight_file, input_names,
                             input_shapes, output_names, output_shapes)
L
Liangliang He 已提交
188

Y
yejianwu 已提交
189 190

def parse_args():
    """Parse the command-line arguments of the validation script.

    Returns:
        The ``(namespace, remaining)`` pair produced by
        ``ArgumentParser.parse_known_args``: unrecognized arguments are
        collected in ``remaining`` instead of causing an error.
    """
    # (flag, default, help), registered in the order shown by --help.
    # Every option is a plain string.
    options = (
        ("--platform", "", "Tensorflow or Caffe."),
        ("--model_file", "",
         "TensorFlow or Caffe 'GraphDef' file to load."),
        ("--weight_file", "", "caffe model file to load."),
        ("--input_file", "", "input file."),
        ("--mace_out_file", "", "mace output file to load."),
        ("--device_type", "", "mace runtime device."),
        ("--input_shape", "1,64,64,3", "input shape."),
        ("--output_shape", "1,64,64,2", "output shape."),
        ("--input_node", "input_node", "input node"),
        ("--output_node", "output_node", "output node"),
    )

    parser = argparse.ArgumentParser()
    for flag, default, help_text in options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    return parser.parse_known_args()
Y
yejianwu 已提交
224 225 226


if __name__ == '__main__':
    # Script entry point: unrecognized command-line arguments end up in
    # ``unparsed`` and are not used.
    FLAGS, unparsed = parse_args()
    validate(platform=FLAGS.platform,
             model_file=FLAGS.model_file,
             weight_file=FLAGS.weight_file,
             input_file=FLAGS.input_file,
             mace_out_file=FLAGS.mace_out_file,
             device_type=FLAGS.device_type,
             input_shape=FLAGS.input_shape,
             output_shape=FLAGS.output_shape,
             input_node=FLAGS.input_node,
             output_node=FLAGS.output_node)