提交 4b600b8a 编写于 作者: 李寅

Integrate tensorflow transform_graph to mace

上级 a46158c8
......@@ -99,54 +99,6 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supp
Prepare your pre-trained TensorFlow model.pb file.
Use `Graph Transform Tool <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md>`__
to optimize your model for inference.
This tool will improve the efficiency of inference by making several optimizations like operators
folding, redundant node removal etc. We strongly recommend MACE users to use it before building.
Usage for CPU/GPU,
.. code:: bash
# CPU/GPU:
./transform_graph \
--in_graph=/path/to/your/tf_model.pb \
--out_graph=/path/to/your/output/tf_model_opt.pb \
--inputs='input node name' \
--outputs='output node name' \
--transforms='strip_unused_nodes(type=float, shape="1,64,64,3")
strip_unused_nodes(type=float, shape="1,64,64,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
flatten_atrous_conv
fold_batch_norms
fold_old_batch_norms
remove_control_dependencies
strip_unused_nodes
sort_by_execution_order'
Usage for DSP,
.. code:: bash
# DSP:
./transform_graph \
--in_graph=/path/to/your/tf_model.pb \
--out_graph=/path/to/your/output/tf_model_opt.pb \
--inputs='input node name' \
--outputs='output node name' \
--transforms='strip_unused_nodes(type=float, shape="1,64,64,3")
strip_unused_nodes(type=float, shape="1,64,64,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
backport_concatv2
quantize_weights(minimum_size=2)
quantize_nodes
remove_control_dependencies
strip_unused_nodes
sort_by_execution_order'
- Caffe
......
......@@ -118,6 +118,8 @@ def main(unused_args):
option.quantize_range_file = FLAGS.quantize_range_file
option.change_concat_ranges = FLAGS.change_concat_ranges
option.cl_mem_type = FLAGS.cl_mem_type
option.device = device_type_map[FLAGS.runtime]
option.data_type = parse_data_type(FLAGS.data_type, option.device)
input_node_names = FLAGS.input_node.split(',')
input_node_shapes = FLAGS.input_shape.split(':')
......@@ -192,10 +194,6 @@ def main(unused_args):
exit(1)
output_graph_def = converter.run()
option.device = device_type_map[FLAGS.runtime]
option.data_type = parse_data_type(
FLAGS.data_type, option.device)
mace_transformer = transformer.Transformer(
option, output_graph_def)
output_graph_def, quantize_activation_info = mace_transformer.run()
......
......@@ -34,6 +34,7 @@ from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph
tf_padding_str = 'padding'
tf_strides_str = 'strides'
......@@ -114,6 +115,40 @@ TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
TFSupportedOps = [six.b(op) for op in TFSupportedOps]
TFTransformGraphOptions = {
base_converter.DeviceType.CPU.value: [
'strip_unused_nodes',
'remove_nodes(op=Identity, op=CheckNumerics)',
'fold_constants(ignore_errors=true)',
'fold_batch_norms',
'fold_old_batch_norms',
'remove_control_dependencies',
'strip_unused_nodes',
'sort_by_execution_order'
],
base_converter.DeviceType.GPU.value: [
'strip_unused_nodes',
'remove_nodes(op=Identity, op=CheckNumerics)',
'fold_constants(ignore_errors=true)',
'flatten_atrous_conv',
'fold_batch_norms',
'fold_old_batch_norms',
'remove_control_dependencies',
'strip_unused_nodes',
'sort_by_execution_order'
],
base_converter.DeviceType.HEXAGON.value: [
'strip_unused_nodes',
'remove_nodes(op=Identity, op=CheckNumerics)',
'fold_constants(ignore_errors=true)',
'fold_batch_norms',
'fold_old_batch_norms',
'remove_control_dependencies',
'strip_unused_nodes',
'sort_by_execution_order'
]
}
class TensorflowConverter(base_converter.ConverterInterface):
"""A class for convert tensorflow frozen model to mace model.
......@@ -233,13 +268,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._placeholders = {}
self.add_shape_info(tf_graph_def)
print("Run transform_graph: %s" % TFTransformGraphOptions[
option.device])
transformed_graph_def = TransformGraph(tf_graph_def,
option.input_nodes.keys(),
option.output_nodes.keys(),
TFTransformGraphOptions[
option.device])
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(tf_graph_def, name='')
tf.import_graph_def(transformed_graph_def, name='')
self._tf_graph = graph
self._skip_tensor = set()
self._output_shape_list = []
self._output_shape_op_list = []
......
......@@ -1182,43 +1182,34 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if op.type == MaceOp.Softmax.name:
# see if possible to fold
# Reshape(xd->2d) + Softmax(2d) + Reshape(xd) to Softmax(xd)
# Reshape(xd->2d) + Softmax(2d) [+ Reshape(xd)] to Softmax(xd)
should_fold = False
if op.input[0] in self._producer \
and self._producer[op.input[0]].type \
== MaceOp.Reshape.name \
and len(op.output_shape[0].dims) == 2 \
and self.consumer_count(op.output[0]) == 1:
producer = self._producer[op.input[0]]
consumer = self._consumers[op.output[0]][0]
if (consumer.type == MaceOp.Reshape.name
and op.output_shape[0].dims[-1]
== consumer.output_shape[0].dims[-1]
and op.output_shape[0].dims[-1] != -1
and self.get_tensor_shape(producer.input[0])
== consumer.output_shape[0].dims):
should_fold = True
and len(op.output_shape[0].dims) == 2:
should_fold = True
if should_fold:
print(
"Fold reshape and softmax: %s(%s)"
% (op.name, op.type))
producer = self._producer[op.input[0]]
consumer = self._consumers[op.output[0]][0]
op.output_shape[0].dims[:] = self.get_tensor_shape(
producer.input[0])
# if there is a shape op, remove it too
if (consumer.input[1] in self._producer
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None,
remove_input_tensor=True)
# remove consumer reshape
self.safe_remove_node(consumer, op,
remove_input_tensor=True)
if op.output[0] in self._consumers:
consumer = self._consumers[op.output[0]][0]
# if there is a shape op, remove it too
if (consumer.input[1] in self._producer
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None,
remove_input_tensor=True)
# remove consumer reshape
self.safe_remove_node(consumer, op,
remove_input_tensor=True)
# remove producer reshape
self.safe_remove_node(producer,
self._producer.get(producer.input[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册