提交 5b7635f6 编写于 作者: L Liangliang He

Merge branch 'caffe' into 'master'

Support caffe model

See merge request !45
...@@ -18,3 +18,12 @@ py_proto_library( ...@@ -18,3 +18,12 @@ py_proto_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = ["@com_google_protobuf//:protobuf_python"], deps = ["@com_google_protobuf//:protobuf_python"],
) )
py_proto_library(
name = "caffe_py",
srcs = ["caffe.proto"],
default_runtime = "@com_google_protobuf//:protobuf_python",
protoc = "@com_google_protobuf//:protoc",
srcs_version = "PY2AND3",
deps = ["@com_google_protobuf//:protobuf_python"],
)
此差异已折叠。
...@@ -13,6 +13,18 @@ py_library( ...@@ -13,6 +13,18 @@ py_library(
], ],
) )
py_library(
name = "caffe_converter_lib",
srcs = [
"caffe_converter_lib.py",
],
srcs_version = "PY2AND3",
deps = [
":memory_optimizer",
"//lib/proto:caffe_py",
],
)
py_library( py_library(
name = "source_converter_lib", name = "source_converter_lib",
srcs = [ srcs = [
...@@ -25,11 +37,12 @@ py_library( ...@@ -25,11 +37,12 @@ py_library(
) )
py_binary( py_binary(
name = "tf_converter", name = "converter",
srcs = ["tf_converter.py"], srcs = ["converter.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":tf_converter_lib", ":tf_converter_lib",
":caffe_converter_lib",
":source_converter_lib", ":source_converter_lib",
"@six_archive//:six", "@six_archive//:six",
], ],
......
此差异已折叠。
import argparse import argparse
import sys import sys
import hashlib import hashlib
import tensorflow as tf import os.path
from tensorflow import gfile
from lib.proto import mace_pb2
from lib.python.tools import tf_converter_lib
from lib.python.tools import tf_dsp_converter_lib
from lib.python.tools import source_converter_lib from lib.python.tools import source_converter_lib
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3 # ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS = None FLAGS = None
...@@ -20,38 +16,57 @@ def file_checksum(fname): ...@@ -20,38 +16,57 @@ def file_checksum(fname):
return hash_func.hexdigest() return hash_func.hexdigest()
def main(unused_args): def main(unused_args):
if not gfile.Exists(FLAGS.input): if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.input + "' does not exist!") print("Input graph file '" + FLAGS.model_file + "' does not exist!")
return -1 return -1
model_checksum = file_checksum(FLAGS.input) model_checksum = file_checksum(FLAGS.model_file)
if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum: if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum)) print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum))
return -1 return -1
input_graph_def = tf.GraphDef() if FLAGS.platform == 'caffe':
with gfile.Open(FLAGS.input, "rb") as f: if not os.path.isfile(FLAGS.weight_file):
data = f.read() print("Input weight file '" + FLAGS.weight_file + "' does not exist!")
input_graph_def.ParseFromString(data) return -1
weight_checksum = file_checksum(FLAGS.weight_file)
if FLAGS.weight_checksum != "" and FLAGS.weight_checksum != weight_checksum:
print("Weight checksum mismatch: %s != %s" % (weight_checksum, FLAGS.weight_checksum))
return -1
if FLAGS.runtime == 'dsp':
print("DSP not support caffe model yet.")
return -1
if FLAGS.runtime == 'dsp':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else:
input_shape = [] input_shape = []
if FLAGS.input_shape != "": if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
output_graph_def = tf_converter_lib.convert_to_mace_pb( from lib.python.tools import caffe_converter_lib
input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node, output_graph_def = caffe_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
elif FLAGS.platform == 'tensorflow':
if FLAGS.runtime == 'dsp':
from lib.python.tools import tf_dsp_converter_lib
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else:
input_shape = []
if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
from lib.python.tools import tf_converter_lib
output_graph_def = tf_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else: else:
with gfile.GFile(FLAGS.output, "wb") as f: with open(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString()) f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f: with open(FLAGS.output + '_txt', "wb") as f:
# output_graph_def.ClearField('tensors') # output_graph_def.ClearField('tensors')
f.write(str(output_graph_def)) f.write(str(output_graph_def))
print("Model conversion is completed.") print("Model conversion is completed.")
...@@ -69,15 +84,25 @@ def parse_args(): ...@@ -69,15 +84,25 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true") parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument( parser.add_argument(
"--input", "--model_file",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load, Caffe prototxt file to load.")
parser.add_argument(
"--weight_file",
type=str, type=str,
default="", default="",
help="TensorFlow \'GraphDef\' file to load.") help="Caffe data file to load.")
parser.add_argument( parser.add_argument(
"--model_checksum", "--model_checksum",
type=str, type=str,
default="", default="",
help="Model file sha256 checksum") help="Model file sha256 checksum")
parser.add_argument(
"--weight_checksum",
type=str,
default="",
help="Weight file sha256 checksum")
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
...@@ -142,6 +167,11 @@ def parse_args(): ...@@ -142,6 +167,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="input shape.") help="input shape.")
parser.add_argument(
"--platform",
type=str,
default="tensorflow",
help="tensorflow/caffe")
parser.add_argument( parser.add_argument(
"--embed_model_data", "--embed_model_data",
type=str2bool, type=str2bool,
......
import struct
import os import os
import uuid import uuid
import numpy as np import numpy as np
import hashlib import hashlib
from tensorflow import gfile
from lib.proto import mace_pb2 from lib.proto import mace_pb2
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
...@@ -82,7 +80,6 @@ def rename_tensor(net_def): ...@@ -82,7 +80,6 @@ def rename_tensor(net_def):
class TensorInfo: class TensorInfo:
def __init__(self, id, t, runtime): def __init__(self, id, t, runtime):
self.id = id self.id = id
self.name = t.name
self.data_type = mace_pb2.DataType.Name(t.data_type) self.data_type = mace_pb2.DataType.Name(t.data_type)
if t.data_type == mace_pb2.DT_FLOAT: if t.data_type == mace_pb2.DT_FLOAT:
if runtime == 'gpu': if runtime == 'gpu':
...@@ -136,7 +133,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -136,7 +133,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
) )
model_data.extend(tensor_info.data) model_data.extend(tensor_info.data)
offset += len(tensor_info.data) offset += len(tensor_info.data)
with gfile.GFile(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f: with open(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f:
f.write(source) f.write(source)
counter += 1 counter += 1
...@@ -148,7 +145,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -148,7 +145,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
model_data_size = offset, model_data_size = offset,
model_data = model_data model_data = model_data
) )
with gfile.GFile(output_dir + 'tensor_data' + '.cc', "wb") as f: with open(output_dir + 'tensor_data' + '.cc', "wb") as f:
f.write(source) f.write(source)
if not embed_model_data: if not embed_model_data:
f = open(output_dir + model_tag + '.data', "wb") f = open(output_dir + model_tag + '.data', "wb")
...@@ -167,7 +164,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -167,7 +164,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
mode = 2, mode = 2,
runtime = runtime, runtime = runtime,
) )
with gfile.GFile(output_dir + 'op' + str(counter) + '.cc', "wb") as f: with open(output_dir + 'op' + str(counter) + '.cc', "wb") as f:
f.write(source) f.write(source)
counter += 1 counter += 1
...@@ -181,5 +178,5 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -181,5 +178,5 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
runtime = runtime, runtime = runtime,
model_pb_checksum = mode_pb_checksum model_pb_checksum = mode_pb_checksum
) )
with gfile.GFile(output, "wb") as f: with open(output, "wb") as f:
f.write(source) f.write(source)
...@@ -3,6 +3,7 @@ import tensorflow as tf ...@@ -3,6 +3,7 @@ import tensorflow as tf
import numpy as np import numpy as np
import math import math
import copy import copy
from tensorflow import gfile
from lib.python.tools import memory_optimizer from lib.python.tools import memory_optimizer
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import tensor_shape_pb2
...@@ -993,10 +994,15 @@ def add_shape_info(input_graph_def, input_node, input_shape): ...@@ -993,10 +994,15 @@ def add_shape_info(input_graph_def, input_node, input_shape):
return inputs_replaced_graph return inputs_replaced_graph
def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, data_type, device, winograd): def convert_to_mace_pb(model_file, input_node, input_shape, output_node, data_type, device, winograd):
net_def = mace_pb2.NetDef() net_def = mace_pb2.NetDef()
dt = data_type_map[data_type] dt = data_type_map[data_type]
input_graph_def = tf.GraphDef()
with gfile.Open(model_file, "rb") as f:
data = f.read()
input_graph_def.ParseFromString(data)
input_graph_def = add_shape_info(input_graph_def, input_node, input_shape) input_graph_def = add_shape_info(input_graph_def, input_node, input_shape)
with tf.Session() as session: with tf.Session() as session:
with session.graph.as_default() as graph: with session.graph.as_default() as graph:
...@@ -1006,7 +1012,7 @@ def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, da ...@@ -1006,7 +1012,7 @@ def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, da
converter.convert(input_node, output_node) converter.convert(input_node, output_node)
optimizer = Optimizer(net_def, device) optimizer = Optimizer(net_def, device)
net_def = optimizer.optimize() net_def = optimizer.optimize()
print "PB Converted." print "Model Converted."
if device == 'gpu': if device == 'gpu':
print "start optimize memory." print "start optimize memory."
mem_optimizer = memory_optimizer.MemoryOptimizer(net_def) mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
......
from lib.proto import mace_pb2 from lib.proto import mace_pb2
import tensorflow as tf import tensorflow as tf
from tensorflow import gfile
from operator import mul from operator import mul
from dsp_ops import DspOps from dsp_ops import DspOps
from lib.python.tools import graph_util from lib.python.tools import graph_util
...@@ -359,12 +360,17 @@ def fuse_quantize(net_def, input_node, output_node): ...@@ -359,12 +360,17 @@ def fuse_quantize(net_def, input_node, output_node):
new_net_def.op.extend(new_ops) new_net_def.op.extend(new_ops)
return new_net_def return new_net_def
def convert_to_mace_pb(input_graph_def, input_node, output_node, dsp_mode): def convert_to_mace_pb(model_file, input_node, output_node, dsp_mode):
""" """
nnlib does not have batch norm, so use tensorflow optimizer to fold nnlib does not have batch norm, so use tensorflow optimizer to fold
batch norm with convolution. The fold optimization reorders ops, so batch norm with convolution. The fold optimization reorders ops, so
we sort ops first by topology. we sort ops first by topology.
""" """
input_graph_def = tf.GraphDef()
with gfile.Open(model_file, "rb") as f:
data = f.read()
input_graph_def.ParseFromString(data)
input_graph_def = graph_util.sort_tf_graph(input_graph_def) input_graph_def = graph_util.sort_tf_graph(input_graph_def)
net_def = mace_pb2.NetDef() net_def = mace_pb2.NetDef()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册