From b0a21b5f976e705732d3cf41c8c71b1a5b6f5d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Tue, 19 Sep 2017 18:02:59 +0800 Subject: [PATCH] Transform tensorflow graph to mace net --- .gitignore | 1 + WORKSPACE | 15 +++ mace/proto/BUILD | 14 +++ mace/python/tools/BUILD | 19 ++++ mace/python/tools/__init__.py | 0 mace/python/tools/tf_converter.py | 48 +++++++++ mace/python/tools/tf_converter_lib.py | 135 ++++++++++++++++++++++++++ mace/third_party/six.BUILD | 14 +++ 8 files changed, 246 insertions(+) create mode 100644 mace/python/tools/BUILD create mode 100644 mace/python/tools/__init__.py create mode 100644 mace/python/tools/tf_converter.py create mode 100644 mace/python/tools/tf_converter_lib.py create mode 100644 mace/third_party/six.BUILD diff --git a/.gitignore b/.gitignore index 482dabe6..d32ce9b1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ tags .idea/ cmake-build-debug/ *.sh +*.pyc diff --git a/WORKSPACE b/WORKSPACE index 247e552b..abd50963 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -27,6 +27,21 @@ new_http_archive( build_file = "mace/third_party/gtest.BUILD", ) +new_http_archive( + name = "six_archive", + urls = [ + "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + ], + sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", + strip_prefix = "six-1.10.0", + build_file = "mace/third_party/six.BUILD", +) +bind( + name = "six", + actual = "@six_archive//:six", +) + # Set up Android NDK android_ndk_repository( name = "androidndk", diff --git a/mace/proto/BUILD b/mace/proto/BUILD index c0ed8820..ed124858 100644 --- a/mace/proto/BUILD +++ b/mace/proto/BUILD @@ -1,10 +1,15 @@ # Description: # mace proto. # + package( default_visibility = ["//visibility:public"], ) +licenses(["notice"]) # Apache 2.0 + +load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") + proto_library( name = "proto", srcs = ["mace.proto"], @@ -14,3 +19,12 @@ cc_proto_library( name = "cc_proto", deps = [":proto"], ) + +py_proto_library( + name = "mace_py", + srcs = ["mace.proto"], + srcs_version = "PY2AND3", + deps = ["@com_google_protobuf//:protobuf_python"], + protoc = "@com_google_protobuf//:protoc", + default_runtime = "@com_google_protobuf//:protobuf_python", +) \ No newline at end of file diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD new file mode 100644 index 00000000..ab7af6c2 --- /dev/null +++ b/mace/python/tools/BUILD @@ -0,0 +1,19 @@ +py_library( + name = "tf_converter_lib", + srcs = ["tf_converter_lib.py"], + srcs_version = "PY2AND3", + deps = [ + "//mace/proto:mace_py", + ], +) + +py_binary( + name = "tf_converter", + srcs = ["tf_converter.py"], + srcs_version = "PY2AND3", + deps = [ + ":tf_converter_lib", + "@six_archive//:six", + ], +) + diff --git a/mace/python/tools/__init__.py b/mace/python/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py new file mode 100644 index 00000000..4797af7e --- /dev/null +++ b/mace/python/tools/tf_converter.py @@ -0,0 +1,48 @@ +import argparse +import sys +import tensorflow as tf +from tensorflow import gfile +from mace.python.tools import tf_converter_lib + +FLAGS = None + + +def main(unused_args): + if not gfile.Exists(FLAGS.input): + print("Input graph file '" + FLAGS.input + "' does not exist!") + return -1 + + input_graph_def = tf.GraphDef() + with gfile.Open(FLAGS.input, "rb") as f: + data = f.read() + input_graph_def.ParseFromString(data) + + output_graph_def = tf_converter_lib.convert_to_mace_pb( + input_graph_def) + + with gfile.GFile(FLAGS.output, "wb") as f: + f.write(output_graph_def.SerializeToString()) + with gfile.GFile(FLAGS.output + '_txt', "wb") as f: + f.write(str(output_graph_def)) + + +def parse_args(): + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--input", + type=str, + default="", + help="TensorFlow \'GraphDef\' file to load.") + parser.add_argument( + "--output", + type=str, + default="", + help="File to save the output graph to.") + return parser.parse_known_args() + + +if __name__ == '__main__': + FLAGS, unparsed = parse_args() + main(unused_args=[sys.argv[0]] + unparsed) diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py new file mode 100644 index 00000000..e119d03d --- /dev/null +++ b/mace/python/tools/tf_converter_lib.py @@ -0,0 +1,135 @@ +from mace.proto import mace_pb2 +import tensorflow as tf + +padding_mode = { + 'VALID': 0, + 'SAME': 1, + 'FULL': 2 +} +pooling_type_mode = { + 'AvgPool': 1, + 'MaxPool': 2 +} + + +def convert_ops(unresolved_ops, net_def): + ops_count = len(unresolved_ops) + resolved_count = 1 + + first_op = unresolved_ops[0] + + if first_op.type == 'Placeholder': + pass + elif first_op.type == 'Const': + tf_tensor = first_op.outputs[0].eval() + tensor = net_def.tensors.add() + tensor.name = first_op.outputs[0].name + tensor.dims.extend(tf_tensor.shape) + # TODO: support other type than float + tensor.data_type = mace_pb2.DT_FLOAT + tensor.float_data.extend(tf_tensor.astype(float).flat) + # net_def.tensors.extend([tensor]) + elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative': + op_def = net_def.op.add() + op_def.name = first_op.name + if first_op.type == 'DepthwiseConv2dNative': + op_def.type = 'DepthwiseConv2d' + else: + op_def.type = first_op.type + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode[first_op.get_attr('padding')] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend(first_op.get_attr('strides')) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = first_op.get_attr('data_format') + + if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd': + bias_add_op = unresolved_ops[1] + op_def.input.extend([bias_add_op.inputs[1].name]) + resolved_count = 2 + elif first_op.type == 'Add' and first_op.name.endswith( + 'batchnorm/add') and ops_count > 7: + add_op = first_op + mul_op = unresolved_ops[2] + mul_1_op = unresolved_ops[3] + mul_2_op = unresolved_ops[4] + sub_op = unresolved_ops[5] + add_1_op = unresolved_ops[6] + # print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type) + if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': + raise Exception('Invalid BatchNorm Op') + + input_name = mul_1_op.inputs[0].name + gamma = mul_op.inputs[1].name + beta = sub_op.inputs[0].name + mean = mul_2_op.inputs[0].name + variance = add_op.inputs[0].name + epsilon = add_op.inputs[1].name + + op_def = net_def.op.add() + op_def.name = first_op.name[:-4] # remove /add + op_def.type = 'BatchNorm' + op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon]) + op_def.output.extend([output.name for output in add_1_op.outputs]) + + resolved_count = 7 + elif first_op.type == 'Relu6': + op_def = net_def.op.add() + op_def.name = first_op.name + op_def.type = 'Relu' + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + max_limit_arg = op_def.arg.add() + max_limit_arg.name = 'max_limit' + max_limit_arg.f = 6 + elif first_op.type == 'Relu': + op_def = net_def.op.add() + op_def.name = first_op.name + op_def.type = first_op.type + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + elif first_op.type == 'AvgPool': + op_def = net_def.op.add() + op_def.name = first_op.name + op_def.type = 'Pooling' + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + pooling_type_arg = op_def.arg.add() + pooling_type_arg.name = 'pooling_type' + pooling_type_arg.i = pooling_type_mode[first_op.type] + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode[first_op.get_attr('padding')] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend(first_op.get_attr('strides')[1:-1]) + strides_arg.name = 'kernels' + strides_arg.ints.extend(first_op.get_attr('ksize')[1:-1]) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = first_op.get_attr('data_format') + else: + raise Exception('Unknown Op: ' + first_op.name) + pass + + for i in range(resolved_count): + del unresolved_ops[0] + + +def convert_to_mace_pb(input_graph_def): + net_def = mace_pb2.NetDef() + + with tf.Session() as session: + with session.graph.as_default() as graph: + tf.import_graph_def(input_graph_def, name="") + ops = graph.get_operations() + unresolved_ops = ops + while len(unresolved_ops) > 0: + convert_ops(unresolved_ops, net_def) + + return net_def diff --git a/mace/third_party/six.BUILD b/mace/third_party/six.BUILD new file mode 100644 index 00000000..a1b2f7b2 --- /dev/null +++ b/mace/third_party/six.BUILD @@ -0,0 +1,14 @@ +# Description: +# Six provides simple utilities for wrapping over differences between Python 2 +# and Python 3. + +licenses(["notice"]) # MIT + +exports_files(["LICENSE"]) + +py_library( + name = "six", + srcs = ["six.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) -- GitLab