提交 6c709f82 编写于 作者: L liuqi

Converter: support FusedBatchNorm, concat and resizebilinear.

上级 5c1264b3
......@@ -30,14 +30,13 @@ def convert_tensor(op, tensor):
else:
raise Exception("Not supported tensor type: " + tf_dt.name)
def get_input_tensor(op, index):
input_tensor = op.inputs[index]
if input_tensor.op.type == 'Reshape':
input_tensor = get_input_tensor(input_tensor.op, 0)
return input_tensor
def convert_ops(unresolved_ops, net_def):
def convert_ops(ops_map, unresolved_ops, net_def):
ops_count = len(unresolved_ops)
resolved_count = 1
......@@ -77,6 +76,28 @@ def convert_ops(unresolved_ops, net_def):
bias_add_op = unresolved_ops[1]
op_def.input.extend([bias_add_op.inputs[1].name])
resolved_count = 2
if ops_count >= 3 and unresolved_ops[1].type == 'Relu':
op_def.name = "FusedConv2D"
resolved_count = 3
elif first_op.type == 'FusedBatchNorm':
op_def = net_def.op.add()
op_def.name = first_op.name
first_op.type = 'BatchNorm'
op_def.input.extend([input.name for input in first_op.inputs])
op_def.output.extend([output.name for output in first_op.outputs])
output_shapes = []
for output in first_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'epsilon'
epsilon_arg.f = first_op.get_attr('epsilon')
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NHWC'
elif first_op.type == 'Add' and first_op.name.endswith(
'batchnorm/add') and ops_count > 7:
add_op = first_op
......@@ -90,6 +111,7 @@ def convert_ops(unresolved_ops, net_def):
mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
raise Exception('Invalid BatchNorm Op')
get_input_tensor(mul_1_op, 0)
input_name = get_input_tensor(mul_1_op, 0).name
gamma = get_input_tensor(mul_op, 1).name
beta = get_input_tensor(sub_op, 0).name
......@@ -168,15 +190,33 @@ def convert_ops(unresolved_ops, net_def):
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = "Concat"
op_def.input.extend([input.name for input in first_op.inputs])
op_def.input.extend([first_op.inputs[i] for i in xrange(2)])
op_def.output.extend([output.name for output in first_op.outputs])
axis_arg = op_def.arg.add()
axis_arg.name = 'axis'
axis_arg.i = get_input_tensor(first_op, 2).eval().astype(np.int32)
output_shapes = []
for output in first_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND',
elif first_op.type == 'ResizeBilinear':
op_def = net_def.op.add()
op_def.name = first_op.name
op_def.type = "ResizeBilinear"
op_def.input.extend(first_op.inputs[0])
op_def.output.extend([output.name for output in first_op.outputs])
size_arg = op_def.arg.add()
size_arg.name = 'size'
size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat)
output_shapes = []
for output in first_op.outputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
elif first_op.type in ['Relu', 'SpaceToBatchND',
'BatchToSpaceND', 'BiasAdd', 'FusedBatchNorm']:
op_def = net_def.op.add()
op_def.name = first_op.name
......@@ -204,9 +244,16 @@ def convert_to_mace_pb(input_graph_def):
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="")
ops = graph.get_operations()
ops_map = {}
for op in ops:
if op.name not in ops_map.keys():
ops_map[op.name] = op
else:
raise ValueError("Duplicate op names detected for ", op.name)
unresolved_ops = ops
while len(unresolved_ops) > 0:
convert_ops(unresolved_ops, net_def)
convert_ops(ops_map, unresolved_ops, net_def)
print "Done."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册