提交 4b600b8a 编写于 作者: 李寅

Integrate tensorflow transform_graph to mace

上级 a46158c8
...@@ -99,54 +99,6 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supp ...@@ -99,54 +99,6 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supp
Prepare your pre-trained TensorFlow model.pb file. Prepare your pre-trained TensorFlow model.pb file.
Use `Graph Transform Tool <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md>`__
to optimize your model for inference.
This tool improves inference efficiency by applying several optimizations, such as operator
folding and redundant node removal. We strongly recommend that MACE users run it before building.
Usage for CPU/GPU:
.. code:: bash
# CPU/GPU:
./transform_graph \
--in_graph=/path/to/your/tf_model.pb \
--out_graph=/path/to/your/output/tf_model_opt.pb \
--inputs='input node name' \
--outputs='output node name' \
--transforms='strip_unused_nodes(type=float, shape="1,64,64,3")
strip_unused_nodes(type=float, shape="1,64,64,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
flatten_atrous_conv
fold_batch_norms
fold_old_batch_norms
remove_control_dependencies
strip_unused_nodes
sort_by_execution_order'
Usage for DSP:
.. code:: bash
# DSP:
./transform_graph \
--in_graph=/path/to/your/tf_model.pb \
--out_graph=/path/to/your/output/tf_model_opt.pb \
--inputs='input node name' \
--outputs='output node name' \
--transforms='strip_unused_nodes(type=float, shape="1,64,64,3")
strip_unused_nodes(type=float, shape="1,64,64,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
backport_concatv2
quantize_weights(minimum_size=2)
quantize_nodes
remove_control_dependencies
strip_unused_nodes
sort_by_execution_order'
- Caffe - Caffe
......
...@@ -118,6 +118,8 @@ def main(unused_args): ...@@ -118,6 +118,8 @@ def main(unused_args):
option.quantize_range_file = FLAGS.quantize_range_file option.quantize_range_file = FLAGS.quantize_range_file
option.change_concat_ranges = FLAGS.change_concat_ranges option.change_concat_ranges = FLAGS.change_concat_ranges
option.cl_mem_type = FLAGS.cl_mem_type option.cl_mem_type = FLAGS.cl_mem_type
option.device = device_type_map[FLAGS.runtime]
option.data_type = parse_data_type(FLAGS.data_type, option.device)
input_node_names = FLAGS.input_node.split(',') input_node_names = FLAGS.input_node.split(',')
input_node_shapes = FLAGS.input_shape.split(':') input_node_shapes = FLAGS.input_shape.split(':')
...@@ -192,10 +194,6 @@ def main(unused_args): ...@@ -192,10 +194,6 @@ def main(unused_args):
exit(1) exit(1)
output_graph_def = converter.run() output_graph_def = converter.run()
option.device = device_type_map[FLAGS.runtime]
option.data_type = parse_data_type(
FLAGS.data_type, option.device)
mace_transformer = transformer.Transformer( mace_transformer = transformer.Transformer(
option, output_graph_def) option, output_graph_def)
output_graph_def, quantize_activation_info = mace_transformer.run() output_graph_def, quantize_activation_info = mace_transformer.run()
......
...@@ -34,6 +34,7 @@ from mace.python.tools.converter_tool.base_converter import ConverterUtil ...@@ -34,6 +34,7 @@ from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check from mace.python.tools.convert_util import mace_check
from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph
tf_padding_str = 'padding' tf_padding_str = 'padding'
tf_strides_str = 'strides' tf_strides_str = 'strides'
...@@ -114,6 +115,40 @@ TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) ...@@ -114,6 +115,40 @@ TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
TFSupportedOps = [six.b(op) for op in TFSupportedOps] TFSupportedOps = [six.b(op) for op in TFSupportedOps]
# Graph Transform Tool passes handed to TransformGraph, keyed by the value of
# the target DeviceType. All devices share one pipeline; GPU additionally runs
# 'flatten_atrous_conv' immediately after constant folding. Each device gets
# its own freshly built list.
TFTransformGraphOptions = {
    device.value: (
        [
            'strip_unused_nodes',
            'remove_nodes(op=Identity, op=CheckNumerics)',
            'fold_constants(ignore_errors=true)',
        ]
        + (['flatten_atrous_conv']
           if device is base_converter.DeviceType.GPU else [])
        + [
            'fold_batch_norms',
            'fold_old_batch_norms',
            'remove_control_dependencies',
            'strip_unused_nodes',
            'sort_by_execution_order',
        ]
    )
    for device in (
        base_converter.DeviceType.CPU,
        base_converter.DeviceType.GPU,
        base_converter.DeviceType.HEXAGON,
    )
}
class TensorflowConverter(base_converter.ConverterInterface): class TensorflowConverter(base_converter.ConverterInterface):
"""A class for convert tensorflow frozen model to mace model. """A class for convert tensorflow frozen model to mace model.
...@@ -233,13 +268,20 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -233,13 +268,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._placeholders = {} self._placeholders = {}
self.add_shape_info(tf_graph_def) self.add_shape_info(tf_graph_def)
print("Run transform_graph: %s" % TFTransformGraphOptions[
option.device])
transformed_graph_def = TransformGraph(tf_graph_def,
option.input_nodes.keys(),
option.output_nodes.keys(),
TFTransformGraphOptions[
option.device])
with tf.Session() as session: with tf.Session() as session:
with session.graph.as_default() as graph: with session.graph.as_default() as graph:
tf.import_graph_def(tf_graph_def, name='') tf.import_graph_def(transformed_graph_def, name='')
self._tf_graph = graph self._tf_graph = graph
self._skip_tensor = set() self._skip_tensor = set()
self._output_shape_list = [] self._output_shape_list = []
self._output_shape_op_list = [] self._output_shape_op_list = []
......
...@@ -1182,43 +1182,34 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1182,43 +1182,34 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op: for op in net.op:
if op.type == MaceOp.Softmax.name: if op.type == MaceOp.Softmax.name:
# see if possible to fold # see if possible to fold
# Reshape(xd->2d) + Softmax(2d) + Reshape(xd) to Softmax(xd) # Reshape(xd->2d) + Softmax(2d) [+ Reshape(xd)] to Softmax(xd)
should_fold = False should_fold = False
if op.input[0] in self._producer \ if op.input[0] in self._producer \
and self._producer[op.input[0]].type \ and self._producer[op.input[0]].type \
== MaceOp.Reshape.name \ == MaceOp.Reshape.name \
and len(op.output_shape[0].dims) == 2 \ and len(op.output_shape[0].dims) == 2:
and self.consumer_count(op.output[0]) == 1: should_fold = True
producer = self._producer[op.input[0]]
consumer = self._consumers[op.output[0]][0]
if (consumer.type == MaceOp.Reshape.name
and op.output_shape[0].dims[-1]
== consumer.output_shape[0].dims[-1]
and op.output_shape[0].dims[-1] != -1
and self.get_tensor_shape(producer.input[0])
== consumer.output_shape[0].dims):
should_fold = True
if should_fold: if should_fold:
print( print(
"Fold reshape and softmax: %s(%s)" "Fold reshape and softmax: %s(%s)"
% (op.name, op.type)) % (op.name, op.type))
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
consumer = self._consumers[op.output[0]][0]
op.output_shape[0].dims[:] = self.get_tensor_shape( op.output_shape[0].dims[:] = self.get_tensor_shape(
producer.input[0]) producer.input[0])
# if there is a shape op, remove it too if op.output[0] in self._consumers:
if (consumer.input[1] in self._producer consumer = self._consumers[op.output[0]][0]
and self._producer[consumer.input[1]].type # if there is a shape op, remove it too
== 'Shape'): if (consumer.input[1] in self._producer
self.safe_remove_node( and self._producer[consumer.input[1]].type
self._producer[consumer.input[1]], None, == 'Shape'):
remove_input_tensor=True) self.safe_remove_node(
# remove consumer reshape self._producer[consumer.input[1]], None,
self.safe_remove_node(consumer, op, remove_input_tensor=True)
remove_input_tensor=True) # remove consumer reshape
self.safe_remove_node(consumer, op,
remove_input_tensor=True)
# remove producer reshape # remove producer reshape
self.safe_remove_node(producer, self.safe_remove_node(producer,
self._producer.get(producer.input[0], self._producer.get(producer.input[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册