Commit ab6e02aa authored by 吴承辉

Merge branch 'master' into 'master'

Refactor mace dsp converter

See merge request !102
@@ -3,7 +3,7 @@ py_library(
     srcs = [
         "tf_converter_lib.py",
         "tf_dsp_converter_lib.py",
-        "tf_graph_util.py"],
+        "graph_util.py"],
     srcs_version = "PY2AND3",
     deps = [
         "//mace/proto:mace_py",
...
 import tensorflow as tf
+from mace.proto import mace_pb2
 from collections import OrderedDict

-def sort_graph_node(node, nodes_map, ordered_nodes_map):
+def sort_tf_node(node, nodes_map, ordered_nodes_map):
   if node.name not in ordered_nodes_map:
     for input_tensor_name in node.input:
       input_node_name = input_tensor_name.split(':')[
@@ -10,17 +11,40 @@ def sort_graph_node(node, nodes_map, ordered_nodes_map):
         continue
       input_node = nodes_map[input_node_name]
-      sort_graph_node(input_node, nodes_map, ordered_nodes_map)
-      ordered_nodes_map[input_node_name] = input_node
+      sort_tf_node(input_node, nodes_map, ordered_nodes_map)
     ordered_nodes_map[node.name] = node


-def sort_graph(graph_def):
+def sort_tf_graph(graph_def):
   nodes_map = {}
   ordered_nodes_map = OrderedDict()
   for node in graph_def.node:
     nodes_map[node.name] = node
   for node in graph_def.node:
-    sort_graph_node(node, nodes_map, ordered_nodes_map)
+    sort_tf_node(node, nodes_map, ordered_nodes_map)
   sorted_graph = tf.GraphDef()
-  sorted_graph.node.extend([node for _, node in ordered_nodes_map.iteritems()])
-  return sorted_graph
\ No newline at end of file
+  sorted_graph.node.extend([node for node in ordered_nodes_map.values()])
+  return sorted_graph
+
+
+def sort_mace_node(node, nodes_map, ordered_nodes_map):
+  if node.name not in ordered_nodes_map:
+    for input_tensor_name in node.input:
+      input_node_name = input_tensor_name.split(':')[
+          0] if ':' in input_tensor_name else input_tensor_name
+      if input_node_name not in nodes_map or input_node_name in ordered_nodes_map:
+        continue
+      input_node = nodes_map[input_node_name]
+      sort_mace_node(input_node, nodes_map, ordered_nodes_map)
+    ordered_nodes_map[node.name] = node
+
+
+def sort_mace_graph(graph_def, output_name):
+  nodes_map = {}
+  ordered_nodes_map = OrderedDict()
+  for node in graph_def.op:
+    nodes_map[node.name] = node
+  sort_mace_node(nodes_map[output_name], nodes_map, ordered_nodes_map)
+  sorted_graph = mace_pb2.NetDef()
+  sorted_graph.tensors.extend(graph_def.tensors)
+  sorted_graph.op.extend([node for node in ordered_nodes_map.values()])
+  return sorted_graph
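
For context, a minimal sketch of how the renamed sort_tf_graph and the new sort_mace_graph might be driven; the model path and node name below are hypothetical placeholders, not files touched by this merge request:

import tensorflow as tf
from mace.python.tools import graph_util

# Load a frozen TensorFlow graph (placeholder path) and sort it so that every
# node appears after the nodes that produce its inputs.
graph_def = tf.GraphDef()
with open('frozen_model.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())
sorted_graph_def = graph_util.sort_tf_graph(graph_def)

# A converted MACE NetDef can be sorted the same way; sort_mace_graph walks
# backwards from the op named by output_name (placeholder name here).
# sorted_net_def = graph_util.sort_mace_graph(net_def, 'output_node')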
@@ -21,7 +21,7 @@ def main(unused_args):
   if FLAGS.runtime == 'dsp':
     output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
-      input_graph_def, FLAGS.input_dim, FLAGS.output_node)
+      input_graph_def, FLAGS.input_node, FLAGS.output_node)
   else:
     output_graph_def = tf_converter_lib.convert_to_mace_pb(
       input_graph_def)
@@ -53,10 +53,10 @@ def parse_args():
     default="cpu",
     help="Runtime: cpu/gpu/dsp.")
   parser.add_argument(
-    "--input_dim",
+    "--input_node",
     type=str,
-    default="input_node,1,28,28,3",
-    help="e.g., input_node,1,28,28,3")
+    default="input_node",
+    help="e.g., input_node")
   parser.add_argument(
     "--output_node",
     type=str,
...
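
With --input_dim replaced by --input_node, the DSP converter now takes bare node names and reads input shapes from the graph itself. A hedged sketch of the updated Python-level call, mirroring main(); the file names are placeholders and the mace tools are assumed to be on PYTHONPATH:

import tensorflow as tf
from mace.python.tools import tf_dsp_converter_lib

input_graph_def = tf.GraphDef()
with open('quantized_model.pb', 'rb') as f:      # placeholder input file
  input_graph_def.ParseFromString(f.read())

# New signature: node names only, no packed 'name,N,H,W,C' dimension string.
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
    input_graph_def, 'input_node', 'output_node')

with open('quantized_model_dsp.pb', 'wb') as f:  # placeholder output file
  f.write(output_graph_def.SerializeToString())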
 from mace.proto import mace_pb2
+# import mace_pb2
 import tensorflow as tf
+import numpy as np
 from operator import mul
 from dsp_ops import DspOps
-from mace.python.tools import tf_graph_util
-# converter --input ../libcv/quantized_icnet.pb --output quantized_icnet_dsp.pb \
-# --runtime dsp --input_dim input_node,1,480,480,3 --output_node icnet/output_node
+from mace.python.tools import graph_util

 padding_mode = {
   'NA': 0,
@@ -15,9 +16,17 @@ padding_mode = {
   'SAME_CAFFE': 5
 }
-node_count = 0
-node_ids = {}
-resolved_ops = set()
+def get_tensor_name_from_op(op_name, port):
+  return op_name + ':' + str(port)
+
+def get_node_from_map(op_map, op_or_tensor_name):
+  op_name = op_or_tensor_name.split(':')[0]
+  return op_map[op_name]
+
+def get_op_and_port_from_tensor(tensor_name):
+  op, port = tensor_name.split(':')
+  port = int(port)
+  return op, port

 def max_elem_size(tensor):
   if len(tensor.shape.as_list()) == 0:
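
The new helpers above replace the node_count/node_ids globals with plain name manipulation; their behavior, assuming they are in scope (the op name is hypothetical):

assert get_tensor_name_from_op('conv1/requantize', 1) == 'conv1/requantize:1'
assert get_op_and_port_from_tensor('conv1/requantize:1') == ('conv1/requantize', 1)
# get_node_from_map accepts either a bare op name or an 'op:port' tensor name
# and returns op_map['conv1/requantize'] in both cases.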
@@ -49,28 +58,15 @@ def get_input_tensor(op, index):
 def add_shape_const_node(net_def, op, values, name):
   print ('Add const node: ', op.name + '/' + name)
-  global node_count
   tensor = net_def.tensors.add()
   node_name = op.name + '/' + name
   tensor.name = node_name + ':0'
-  tensor.node_id = node_count
-  node_count += 1
-  register_node_id(node_name, tensor.node_id)
   tensor.data_type = mace_pb2.DT_INT32
   tensor.dims.extend(values)
   return tensor.name

-def register_node_id(node_name, node_id):
-  global node_ids
-  node_ids[node_name] = node_id
-
-def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
-  global node_count
-  ops_count = len(unresolved_ops)
-  resolved_count = 1
+def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
   first_op = unresolved_ops[0]
   print ('Op: ', first_op.name, first_op.type, first_op.outputs[0].shape)
   if first_op.name in resolved_ops:
@@ -81,9 +77,6 @@ def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
     tf_tensor = first_op.outputs[0].eval()
     tensor = net_def.tensors.add()
     tensor.name = first_op.outputs[0].name
-    tensor.node_id = node_count
-    node_count += 1
-    register_node_id(tensor.name.split(':')[0], tensor.node_id)
     tensor.data_type = find_dtype(first_op.outputs[0].dtype)
     shape = list(tf_tensor.shape)
     if len(shape) > 0:
@@ -101,8 +94,6 @@ def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
     op_def = net_def.op.add()
     op_def.name = first_op.name
     op_def.type = dsp_ops.map_nn_op(first_op.type)
-    op_def.node_id = node_count
-    node_count += 1
     op_def.padding = padding_mode['NA']
     if len(first_op.outputs) > 0 and first_op.type == 'Dequantize' \
@@ -129,7 +120,6 @@ def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
       op_def.input.extend([t.name for t in s2b_op.inputs[1:]])
       op_def.input.extend([min_tensor.name, max_tensor.name])
       op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_op.outputs])
-
     elif has_padding_and_strides(first_op):
       op_def.padding = padding_mode[first_op.get_attr('padding')]
       op_def.input.extend([t.name for t in first_op.inputs])
@@ -148,91 +138,170 @@ def convert_ops(unresolved_ops, net_def, output_node, dsp_ops):
     elif dsp_ops.has_op(first_op.type):
       op_def.input.extend([t.name for t in first_op.inputs])
       op_def.out_max_byte_size.extend([max_elem_size(out) for out in first_op.outputs])
-      if first_op.type == 'Placeholder':
-        input_info = net_def.input_info.add()
-        input_info.name = op_def.name
-        input_info.node_id = op_def.node_id
-        input_info.dims.extend(first_op.outputs[0].shape.as_list())
-        input_info.max_byte_size = max_elem_size(first_op.outputs[0])
-        input_info.data_type = find_dtype(first_op.outputs[0].dtype)
-      elif first_op.name == output_node:
-        output_info = net_def.output_info.add()
-        output_info.name = op_def.name
-        output_info.node_id = op_def.node_id
-        output_info.dims.extend(first_op.outputs[0].shape.as_list())
-        output_info.max_byte_size = max_elem_size(first_op.outputs[0])
-        output_info.data_type = find_dtype(first_op.outputs[0].dtype)
     else:
       raise Exception('Unsupported op: ', first_op)
-    register_node_id(op_def.name, op_def.node_id)
-    print ('Add op node: ', first_op.name)
-    for t in op_def.input:
-      node, port = t.split(':')
-      node_id = node_ids[node]
-      node_input = op_def.node_input.add()
-      node_input.node_id = node_id
-      node_input.output_port = int(port)
   resolved_ops.add(first_op.name)
-  for i in range(resolved_count):
-    del unresolved_ops[0]
+  del unresolved_ops[0]
 def add_output_node(net_def, output_node):
-  global node_count
   op_def = net_def.op.add()
   op_def.name = 'output'
   op_def.type = 'OUTPUT'
-  op_def.node_id = node_count
-  node_count += 1
-  register_node_id(op_def.name, op_def.node_id)
-  op_def.input.extend([output_node + ':0'])
-  node_input = op_def.node_input.add()
-  node_input.node_id = node_ids[output_node]
-  node_input.output_port = 0
+  op_def.input.extend([get_tensor_name_from_op(output_node, 0)])
+
+def reverse_batch_to_space_and_biasadd(net_def):
+  tensor_map = {}
+  for tensor in net_def.tensors:
+    tensor_map[tensor.name] = tensor
+  op_map = {}
+  for op in net_def.op:
+    op_map[op.name] = op
+  consumers = {}
+  for op in net_def.op:
+    for ipt in op.input:
+      if ipt not in consumers:
+        consumers[ipt] = []
+      consumers[ipt].append(op)
+
+  new_ops = []
+  skip_ops = set()
+  visited_ops = set()
+  for op in net_def.op:
+    if op.name in visited_ops:
+      pass
+    # pattern: QConv -> RR -> R -> QB2S -> QBiasAdd -> RR -> R
+    success = False
+    if op.type == 'Requantize_32to8':
+      biasadd_requantize_op = op
+      biasadd_op = get_node_from_map(op_map, biasadd_requantize_op.input[0])
+      if biasadd_op.type == 'QuantizedBiasAdd_8p8to32':
+        b2s_op = get_node_from_map(op_map, biasadd_op.input[0])
+        if b2s_op.type == 'QuantizedBatchToSpaceND_8':
+          conv_requantize_op = get_node_from_map(op_map, b2s_op.input[0])
+          conv_op = get_node_from_map(op_map, conv_requantize_op.input[0])
+          if conv_op.type == 'QuantizedConv2d_8x8to32':
+            new_biasadd_op = mace_pb2.OperatorDef()
+            new_biasadd_op.CopyFrom(biasadd_op)
+            new_biasadd_op.input[0] = get_tensor_name_from_op(conv_requantize_op.name, 0)
+            new_biasadd_op.input[2] = get_tensor_name_from_op(conv_requantize_op.name, 1)
+            new_biasadd_op.input[3] = get_tensor_name_from_op(conv_requantize_op.name, 2)
+            new_biasadd_op.out_max_byte_size[0] = conv_requantize_op.out_max_byte_size[0] * 4
+
+            new_biasadd_requantize_op = mace_pb2.OperatorDef()
+            new_biasadd_requantize_op.CopyFrom(biasadd_requantize_op)
+            new_biasadd_requantize_op.out_max_byte_size[0] = new_biasadd_op.out_max_byte_size[0] / 4
+
+            new_b2s_op = mace_pb2.OperatorDef()
+            new_b2s_op.CopyFrom(b2s_op)
+            new_b2s_op.input[0] = get_tensor_name_from_op(biasadd_requantize_op.name, 0)
+            new_b2s_op.input[3] = get_tensor_name_from_op(biasadd_requantize_op.name, 1)
+            new_b2s_op.input[4] = get_tensor_name_from_op(biasadd_requantize_op.name, 2)
+
+            new_ops.extend([new_biasadd_op, new_biasadd_requantize_op, new_b2s_op])
+            skip_ops = skip_ops.union([biasadd_op.name, biasadd_requantize_op.name, b2s_op.name])
+            visited_ops.add(op.name)
+
+            follow_ops = consumers[get_tensor_name_from_op(biasadd_requantize_op.name, 0)]
+            for follow_op in follow_ops:
+              new_follow_op = mace_pb2.OperatorDef()
+              new_follow_op.CopyFrom(follow_op)
+              for i in range(len(follow_op.input)):
+                for k in range(3):
+                  if new_follow_op.input[i] == get_tensor_name_from_op(biasadd_requantize_op.name, k):
+                    new_follow_op.input[i] = get_tensor_name_from_op(b2s_op.name, k)
+              new_ops.append(new_follow_op)
+              skip_ops.add(follow_op.name)
+              visited_ops.add(follow_op.name)
+    visited_ops.add(op.name)
+
+  new_net_def = mace_pb2.NetDef()
+  new_net_def.tensors.extend(tensor_map.values())
+  for op in net_def.op:
+    if op.name not in skip_ops:
+      new_net_def.op.extend([op])
+  new_net_def.op.extend(new_ops)
+  return new_net_def
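
reverse_batch_to_space_and_biasadd targets the quantized dilated-convolution pattern named in the comment above and swaps the order of BiasAdd and BatchToSpaceND; it is introduced here but still commented out at the call site in convert_to_mace_pb below. A hedged reading of the rewiring, as comments:

# before: QuantizedConv2d_8x8to32 -> Requantize_32to8 (conv_requantize)
#           -> QuantizedBatchToSpaceND_8 -> QuantizedBiasAdd_8p8to32
#           -> Requantize_32to8 (biasadd_requantize) -> consumers
# after:  QuantizedConv2d_8x8to32 -> Requantize_32to8 (conv_requantize)
#           -> QuantizedBiasAdd_8p8to32 -> Requantize_32to8 (biasadd_requantize)
#           -> QuantizedBatchToSpaceND_8 -> consumers
# The min/max outputs (ports 1 and 2 of each requantize) are rewired the same
# way, and consumers of the old biasadd_requantize outputs are repointed to the
# corresponding BatchToSpaceND outputs.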
+
+def add_node_id(net_def):
+  node_id_counter = 0
+  node_id_map = {}
+  for tensor in net_def.tensors:
+    tensor.node_id = node_id_counter
+    node_id_counter += 1
+    tensor_op, port = get_op_and_port_from_tensor(tensor.name)
+    node_id_map[tensor_op] = tensor.node_id
+  for op in net_def.op:
+    op.node_id = node_id_counter
+    node_id_counter += 1
+    node_id_map[op.name] = op.node_id
+    for ipt in op.input:
+      op_name, port = get_op_and_port_from_tensor(ipt)
+      node_id = node_id_map[op_name]
+      node_input = op.node_input.add()
+      node_input.node_id = node_id
+      node_input.output_port = int(port)
+  return net_def
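
add_node_id assigns IDs in a single pass after conversion, tensors first and then ops, and fills each op's node_input from the producer's node_id plus the tensor's output port. A small check of the numbering, using hypothetical names and assuming mace_pb2 and the add_node_id above are in scope:

net_def = mace_pb2.NetDef()
weights = net_def.tensors.add()
weights.name = 'weights:0'
bias = net_def.tensors.add()
bias.name = 'bias:0'
conv = net_def.op.add()
conv.name = 'conv'
conv.input.extend(['weights:0', 'bias:0'])
out = net_def.op.add()
out.name = 'output'
out.input.extend(['conv:0'])
add_node_id(net_def)
# Tensors get node_id 0 and 1, ops get 2 and 3; 'output' ends up with one
# node_input entry whose node_id == 2 (the id of 'conv') and output_port == 0.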
+
+def add_input_output_info(net_def, input_node, output_node, graph):
+  input_tensor = graph.get_tensor_by_name(get_tensor_name_from_op(input_node, 0))
+  output_tensor = graph.get_tensor_by_name(get_tensor_name_from_op(output_node, 0))
+
+  for op in net_def.op:
+    if op.name == input_node:
+      input_info = net_def.input_info.add()
+      input_info.name = op.name
+      input_info.node_id = op.node_id
+      input_info.dims.extend(input_tensor.shape.as_list())
+      input_info.max_byte_size = max_elem_size(input_tensor)
+      input_info.data_type = find_dtype(input_tensor.dtype)
+    elif op.name == output_node:
+      output_info = net_def.output_info.add()
+      output_info.name = op.name
+      output_info.node_id = op.node_id
+      output_info.dims.extend(output_tensor.shape.as_list())
+      output_info.max_byte_size = max_elem_size(output_tensor)
+      output_info.data_type = find_dtype(output_tensor.dtype)
+  return net_def
-def convert_to_mace_pb(input_graph_def, input_dim, output_node):
+def convert_to_mace_pb(input_graph_def, input_node, output_node):
   """
   nnlib does not have batch norm, so use tensorflow optimizer to fold
   batch norm with convolution. The fold optimization reorders ops, so
   we sort ops first by topology.
   """
-  input_graph_def = tf_graph_util.sort_graph(input_graph_def)
-  inputs = input_dim.split(';')
-  input_shape = {}
-  for input in inputs:
-    input_name_shape = input.split(',')
-    name = input_name_shape[0]
-    shape = [int(d) for d in input_name_shape[1:]]
-    input_shape[name] = shape
+  input_graph_def = graph_util.sort_tf_graph(input_graph_def)
   net_def = mace_pb2.NetDef()
-  for node in input_graph_def.node:
-    if node.op == 'Placeholder':
-      node.attr['shape'].shape.unknown_rank = False
-      for d in input_shape[node.name]:
-        dim = node.attr['shape'].shape.dim.add()
-        dim.size = d

   with tf.Session() as session:
     with session.graph.as_default() as graph:
       tf.import_graph_def(input_graph_def, name="")
       ops = graph.get_operations()
       dsp_ops = DspOps()
+      resolved_ops = set()
       # convert const node
       unresolved_ops = [op for op in ops if op.type == 'Const']
       while len(unresolved_ops) > 0:
-        convert_ops(unresolved_ops, net_def, output_node, dsp_ops)
+        convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops)

       # convert op node
       unresolved_ops = [op for op in ops if op.type != 'Const']
       while len(unresolved_ops) > 0:
-        convert_ops(unresolved_ops, net_def, output_node, dsp_ops)
+        convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops)

       add_output_node(net_def, output_node)
-
-  return net_def
+      # optimized_net_def = reverse_batch_to_space_and_biasadd(net_def)
+      # sorted_net_def = graph_util.sort_mace_graph(optimized_net_def, output_node)
+      net_def_with_node_id = add_node_id(net_def)
+      final_net_def = add_input_output_info(net_def_with_node_id, input_node, output_node, graph)
+
+      return final_net_def
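
The docstring above explains why the graph is topologically sorted first: nnlib has no batch norm op, so batch norm is expected to be folded into the preceding convolutions by a TensorFlow optimization pass, which reorders ops. A hedged sketch of such a folding pass using TF 1.x's graph transform tool; the transform names, node names, and file path are assumptions, not part of this change:

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from mace.python.tools import tf_dsp_converter_lib

input_graph_def = tf.GraphDef()
with open('frozen_model.pb', 'rb') as f:   # placeholder frozen graph
  input_graph_def.ParseFromString(f.read())

# Fold batch norm parameters into the preceding convolution weights, then
# hand the folded graph to the DSP converter.
folded_graph_def = TransformGraph(
    input_graph_def, ['input_node'], ['output_node'],
    ['fold_batch_norms', 'fold_old_batch_norms'])

net_def = tf_dsp_converter_lib.convert_to_mace_pb(
    folded_graph_def, 'input_node', 'output_node')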