Commit 0b75406f authored by liuqi

Fix tf_converter to support the GCN model.

Parent 6c709f82
......@@ -24,7 +24,7 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize)
else:
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def)
input_graph_def, FLAGS.runtime)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
......
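The main script change above forwards a new FLAGS.runtime value into convert_to_mace_pb. The flag declaration itself is outside this hunk; below is a minimal sketch of how it could be defined, assuming the tool registers its flags with tf.app.flags like the other FLAGS.* values (the 'cpu' default is an assumption).

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
# Hypothetical declaration: the converter only checks for the value 'gpu',
# so any other string keeps the plain CPU conversion path.
tf.app.flags.DEFINE_string('runtime', 'cpu',
                           "Target runtime; 'gpu' inserts BufferToImage ops for weights.")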
......@@ -13,6 +13,12 @@ pooling_type_mode = {
'MaxPool': 2
}
buffer_type_map = {
'FILTER': 0,
'IN_OUT': 1,
'ARGUMENT': 2,
}
def convert_tensor(op, tensor):
tf_tensor = op.outputs[0].eval()
tensor.name = op.outputs[0].name
......@@ -36,7 +42,21 @@ def get_input_tensor(op, index):
input_tensor = get_input_tensor(input_tensor.op, 0)
return input_tensor
def convert_ops(ops_map, unresolved_ops, net_def):
def add_buffer_to_image(input_name, input_type, net_def):
output_name = input_name[:-2] + "_b2i" + input_name[-2:]
op_def = net_def.op.add()
op_def.name = output_name
op_def.type = 'BufferToImage'
op_def.input.extend([input_name])
buffer_type_arg = op_def.arg.add()
buffer_type_arg.name = 'buffer_type'
buffer_type_arg.i = buffer_type_map[input_type]
mode_arg = op_def.arg.add()
mode_arg.name = 'mode'
mode_arg.i = 0
return output_name
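Illustrative sketch only (not part of the change; the tensor name below is hypothetical): add_buffer_to_image splices "_b2i" in front of the trailing TF tensor index ":0" to name the BufferToImage output, and records the buffer type from buffer_type_map in the op's 'buffer_type' argument.

input_name = 'conv1/weights:0'   # hypothetical TF tensor name
output_name = input_name[:-2] + '_b2i' + input_name[-2:]
assert output_name == 'conv1/weights_b2i:0'
assert buffer_type_map['FILTER'] == 0   # value written to the 'buffer_type' arg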
def convert_ops(unresolved_ops, net_def, device):
ops_count = len(unresolved_ops)
resolved_count = 1
......@@ -54,14 +74,13 @@ def convert_ops(ops_map, unresolved_ops, net_def):
op_def.type = 'DepthwiseConv2d'
else:
op_def.type = first_op.type
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
output_shapes = []
for output in first_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
if device == 'gpu':
op_def.input.extend([first_op.inputs[0].name])
output_name = add_buffer_to_image(first_op.inputs[1].name, "FILTER", net_def)
op_def.input.extend([output_name])
else:
op_def.input.extend([input.name for input in first_op.inputs])
padding_arg = op_def.arg.add()
padding_arg.name = 'padding'
padding_arg.i = padding_mode[first_op.get_attr('padding')]
......@@ -71,27 +90,53 @@ def convert_ops(ops_map, unresolved_ops, net_def):
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NHWC'
final_op = first_op
if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
bias_add_op = unresolved_ops[1]
op_def.input.extend([bias_add_op.inputs[1].name])
resolved_count = 2
if ops_count >= 3 and unresolved_ops[1].type == 'Const' and unresolved_ops[2].type == 'BiasAdd':
bias_tensor = unresolved_ops[1]
tensor = net_def.tensors.add()
convert_tensor(bias_tensor, tensor)
if ops_count >= 3 and unresolved_ops[1].type == 'Relu':
op_def.name = "FusedConv2D"
bias_add_op = unresolved_ops[2]
if device == 'gpu':
output_name = add_buffer_to_image(bias_add_op.inputs[1].name, "ARGUMENT", net_def)
op_def.input.extend([output_name])
else:
op_def.input.extend([bias_add_op.inputs[1].name])
final_op = bias_add_op
resolved_count = 3
elif first_op.type == 'FusedBatchNorm':
op_def = net_def.op.add()
op_def.name = first_op.name
first_op.type = 'BatchNorm'
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
if ops_count >= 4 and unresolved_ops[3].type == 'Relu':
relu_op = unresolved_ops[3]
op_def.type = "FusedConv2D"
final_op = relu_op
resolved_count = 4
op_def.output.extend([output.name for output in final_op.outputs])
output_shapes = []
for output in first_op.outputs:
for output in final_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
elif first_op.type == 'FusedBatchNorm':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = 'BatchNorm'
if device == 'gpu':
op_def.input.extend([first_op.inputs[0].name])
for i in range(1, len(first_op.inputs)):
output_name = add_buffer_to_image(first_op.inputs[i].name, "ARGUMENT", net_def)
op_def.input.extend([output_name])
else:
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([first_op.outputs[0].name])
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(first_op.outputs[0].shape.as_list())
op_def.output_shape.extend([output_shape])
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'epsilon'
epsilon_arg.f = first_op.get_attr('epsilon')
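With device == 'gpu', every filter and argument tensor feeding Conv2D/DepthwiseConv2d and BatchNorm is now routed through a BufferToImage op, so the converted NetDef carries extra ops compared with the CPU path. A small hypothetical helper (not part of this change) to sanity-check that:

def count_buffer_to_image_ops(net_def):
  # Count the ops inserted by add_buffer_to_image; on the GPU path this should
  # equal the number of filter/argument inputs that were rewired.
  return sum(1 for op in net_def.op if op.type == 'BufferToImage')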
......@@ -190,7 +235,7 @@ def convert_ops(ops_map, unresolved_ops, net_def):
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = "Concat"
op_def.input.extend([first_op.inputs[i] for i in xrange(2)])
op_def.input.extend([first_op.inputs[i].name for i in xrange(2)])
op_def.output.extend([output.name for output in first_op.outputs])
axis_arg = op_def.arg.add()
axis_arg.name = 'axis'
......@@ -205,19 +250,21 @@ def convert_ops(ops_map, unresolved_ops, net_def):
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = "ResizeBilinear"
op_def.input.extend(first_op.inputs[0])
op_def.input.extend([first_op.inputs[0].name])
op_def.output.extend([output.name for output in first_op.outputs])
size_arg = op_def.arg.add()
size_arg.name = 'size'
size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat)
align_corners_arg = op_def.arg.add()
align_corners_arg.name = 'align_corners'
align_corners_arg.i = first_op.get_attr('align_corners')
output_shapes = []
for output in first_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
elif first_op.type in ['Relu', 'SpaceToBatchND',
'BatchToSpaceND', 'BiasAdd', 'FusedBatchNorm']:
elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND', 'BiasAdd']:
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = first_op.type
......@@ -237,24 +284,17 @@ def convert_ops(ops_map, unresolved_ops, net_def):
del unresolved_ops[0]
def convert_to_mace_pb(input_graph_def):
def convert_to_mace_pb(input_graph_def, device):
net_def = mace_pb2.NetDef()
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="")
ops = graph.get_operations()
ops_map = {}
for op in ops:
if op.name not in ops_map.keys():
ops_map[op.name] = op
else:
raise ValueError("Duplicate op names detected for ", op.name)
unresolved_ops = ops
while len(unresolved_ops) > 0:
convert_ops(ops_map, unresolved_ops, net_def)
convert_ops(unresolved_ops, net_def, device)
print "Done."
print "PB Parsed."
return net_def
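For reference, a minimal usage sketch of the new signature (file names are hypothetical; this mirrors what the main script does with FLAGS.runtime):

import tensorflow as tf
from tensorflow.python.platform import gfile

graph_def = tf.GraphDef()
with gfile.GFile('frozen_graph.pb', 'rb') as f:   # hypothetical frozen TF model
  graph_def.ParseFromString(f.read())
net_def = convert_to_mace_pb(graph_def, 'gpu')    # 'gpu' enables the BufferToImage path
with gfile.GFile('mace_net.pb', 'wb') as f:       # hypothetical output path
  f.write(net_def.SerializeToString())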