提交 b0a21b5f 编写于 作者: 李寅

Transform tensorflow graph to mace net

上级 5941a8a4
......@@ -3,3 +3,4 @@ tags
.idea/
cmake-build-debug/
*.sh
*.pyc
......@@ -27,6 +27,21 @@ new_http_archive(
build_file = "mace/third_party/gtest.BUILD",
)
new_http_archive(
name = "six_archive",
urls = [
"http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
"https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
],
sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
strip_prefix = "six-1.10.0",
build_file = "mace/third_party/six.BUILD",
)
bind(
name = "six",
actual = "@six_archive//:six",
)
# Set up Android NDK
android_ndk_repository(
name = "androidndk",
......
# Description:
# mace proto.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
proto_library(
name = "proto",
srcs = ["mace.proto"],
......@@ -14,3 +19,12 @@ cc_proto_library(
name = "cc_proto",
deps = [":proto"],
)
py_proto_library(
name = "mace_py",
srcs = ["mace.proto"],
srcs_version = "PY2AND3",
deps = ["@com_google_protobuf//:protobuf_python"],
protoc = "@com_google_protobuf//:protoc",
default_runtime = "@com_google_protobuf//:protobuf_python",
)
\ No newline at end of file
py_library(
name = "tf_converter_lib",
srcs = ["tf_converter_lib.py"],
srcs_version = "PY2AND3",
deps = [
"//mace/proto:mace_py",
],
)
py_binary(
name = "tf_converter",
srcs = ["tf_converter.py"],
srcs_version = "PY2AND3",
deps = [
":tf_converter_lib",
"@six_archive//:six",
],
)
import argparse
import sys
import tensorflow as tf
from tensorflow import gfile
from mace.python.tools import tf_converter_lib
FLAGS = None
def main(unused_args):
if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1
input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.input, "rb") as f:
data = f.read()
input_graph_def.ParseFromString(data)
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
f.write(str(output_graph_def))
def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--output",
type=str,
default="",
help="File to save the output graph to.")
return parser.parse_known_args()
if __name__ == '__main__':
FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed)
from mace.proto import mace_pb2
import tensorflow as tf
padding_mode = {
'VALID': 0,
'SAME': 1,
'FULL': 2
}
pooling_type_mode = {
'AvgPool': 1,
'MaxPool': 2
}
def convert_ops(unresolved_ops, net_def):
ops_count = len(unresolved_ops)
resolved_count = 1
first_op = unresolved_ops[0]
if first_op.type == 'Placeholder':
pass
elif first_op.type == 'Const':
tf_tensor = first_op.outputs[0].eval()
tensor = net_def.tensors.add()
tensor.name = first_op.outputs[0].name
tensor.dims.extend(tf_tensor.shape)
# TODO: support other type than float
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(tf_tensor.astype(float).flat)
# net_def.tensors.extend([tensor])
elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative':
op_def = net_def.op.add()
op_def.name = first_op.name
if first_op.type == 'DepthwiseConv2dNative':
op_def.type = 'DepthwiseConv2d'
else:
op_def.type = first_op.type
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
padding_arg = op_def.arg.add()
padding_arg.name = 'padding'
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides'))
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = first_op.get_attr('data_format')
if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
bias_add_op = unresolved_ops[1]
op_def.input.extend([bias_add_op.inputs[1].name])
resolved_count = 2
elif first_op.type == 'Add' and first_op.name.endswith(
'batchnorm/add') and ops_count > 7:
add_op = first_op
mul_op = unresolved_ops[2]
mul_1_op = unresolved_ops[3]
mul_2_op = unresolved_ops[4]
sub_op = unresolved_ops[5]
add_1_op = unresolved_ops[6]
# print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type)
if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
raise Exception('Invalid BatchNorm Op')
input_name = mul_1_op.inputs[0].name
gamma = mul_op.inputs[1].name
beta = sub_op.inputs[0].name
mean = mul_2_op.inputs[0].name
variance = add_op.inputs[0].name
epsilon = add_op.inputs[1].name
op_def = net_def.op.add()
op_def.name = first_op.name[:-4] # remove /add
op_def.type = 'BatchNorm'
op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon])
op_def.output.extend([output.name for output in add_1_op.outputs])
resolved_count = 7
elif first_op.type == 'Relu6':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = 'Relu'
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
max_limit_arg = op_def.arg.add()
max_limit_arg.name = 'max_limit'
max_limit_arg.f = 6
elif first_op.type == 'Relu':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = first_op.type
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
elif first_op.type == 'AvgPool':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = 'Pooling'
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
pooling_type_arg = op_def.arg.add()
pooling_type_arg.name = 'pooling_type'
pooling_type_arg.i = pooling_type_mode[first_op.type]
padding_arg = op_def.arg.add()
padding_arg.name = 'padding'
padding_arg.i = padding_mode[first_op.get_attr('padding')]
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend(first_op.get_attr('strides')[1:-1])
strides_arg.name = 'kernels'
strides_arg.ints.extend(first_op.get_attr('ksize')[1:-1])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = first_op.get_attr('data_format')
else:
raise Exception('Unknown Op: ' + first_op.name)
pass
for i in range(resolved_count):
del unresolved_ops[0]
def convert_to_mace_pb(input_graph_def):
net_def = mace_pb2.NetDef()
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="")
ops = graph.get_operations()
unresolved_ops = ops
while len(unresolved_ops) > 0:
convert_ops(unresolved_ops, net_def)
return net_def
# Description:
# Six provides simple utilities for wrapping over differences between Python 2
# and Python 3.
licenses(["notice"]) # MIT
exports_files(["LICENSE"])
py_library(
name = "six",
srcs = ["six.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册